"vscode:/vscode.git/clone" did not exist on "e69a92a1cea23b36803caac2d251d906789eed1d"
phi3_small.py 18.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import math
4
from typing import Iterable, List, Optional, Set, Tuple, Union
5
6
7
8
9
10

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
11
from vllm.config import CacheConfig, VllmConfig
12
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
13
14
15
16
17
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
20
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
21
22
23
24
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.platforms import current_platform
26
from vllm.sequence import IntermediateTensors
27

28
29
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
30
31
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
32

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

def load_column_parallel_weight(param: torch.nn.Parameter,
                                loaded_weight: torch.Tensor):
    """Copy this tensor-parallel rank's dim-0 shard of a full weight.

    The full checkpoint tensor ``loaded_weight`` is cut into ``tp_size``
    equal, contiguous pieces along dimension 0; piece number ``rank`` is
    copied into ``param`` in place.

    Args:
        param: The local (already sharded) parameter to fill.
        loaded_weight: The unsharded tensor read from the checkpoint.

    Raises:
        AssertionError: If the full tensor is not exactly ``tp_size`` times
            the local shard along dim 0, or if the slice shape differs from
            ``param``'s shape.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    rows_per_rank = param.size(0)
    assert rows_per_rank * world_size == loaded_weight.size(0)

    first_row = rank * rows_per_rank
    shard = loaded_weight[first_row:first_row + rows_per_rank]
    assert param.shape == shard.shape
    param.data.copy_(shard)


class HeadMajorQKVParallelLinear(QKVParallelLinear):
    """QKV projection whose checkpoint weight is sharded by plain row slicing.

    The inherited ``weight_loader`` is replaced so the fused QKV tensor is
    split across tensor-parallel ranks as contiguous dim-0 chunks (see
    ``load_column_parallel_weight``). NOTE(review): the class name suggests
    the checkpoint stores heads in head-major order, so contiguous slicing
    keeps each rank's heads together -- confirm against the checkpoint.
    """

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param`` via contiguous row slicing."""
        return load_column_parallel_weight(param=param,
                                           loaded_weight=loaded_weight)


class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
    """Merged column-parallel linear layer loaded by plain row slicing.

    The inherited ``weight_loader`` is replaced so the whole merged weight
    is split across tensor-parallel ranks as contiguous dim-0 chunks (see
    ``load_column_parallel_weight``) instead of per sub-matrix.
    """

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param`` via contiguous row slicing."""
        return load_column_parallel_weight(param=param,
                                           loaded_weight=loaded_weight)


60
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def quick_gelu(x):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
    scaled = 1.702 * x
    return x * scaled.sigmoid()


65
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def gegelu(input, limit: Optional[float] = None):
    """Gated quick-GELU over interleaved channels.

    Along the last dimension, even-indexed channels form the gate branch and
    odd-indexed channels form the linear branch. The result is
    ``quick_gelu(gate) * (linear + 1)``, so the last dimension is halved.

    Args:
        input: Tensor whose last dimension interleaves gate/linear channels.
        limit: Optional clamp bound. When given, finite gate values are
            clamped from above at ``limit`` and finite linear values into
            ``[-limit, limit]``; infinite entries are left untouched.
    """
    gate = input[..., ::2]
    linear = input[..., 1::2]

    if limit is not None:
        gate_is_inf = torch.isinf(gate)
        gate = torch.where(gate_is_inf, gate,
                           gate.clamp(min=None, max=limit))
        linear_is_inf = torch.isinf(linear)
        linear = torch.where(linear_is_inf, linear,
                             linear.clamp(min=-limit, max=limit))

    return quick_gelu(gate) * (linear + 1)


class Phi3SmallMLP(nn.Module):
    """Feed-forward block of Phi-3-small using the ``gegelu`` activation.

    ``up_proj`` maps the hidden state to ``2 * intermediate_size`` channels.
    ``gegelu`` consumes them as interleaved gate/linear pairs and halves the
    width. ``down_proj`` maps the result back to ``hidden_size``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        assert (self.config.hidden_act == "gegelu"
                ), "Only `gegelu` is supported for the 4.7 series of models .."
        self.hidden_size = config.hidden_size
        # Clamp bound for the gegelu activation (None disables clamping).
        self.gegelu_limit = config.gegelu_limit
        self.intermediate_size = config.intermediate_size

        self.up_proj = HeadMajorColumnParallelLinear(
            self.hidden_size,
            2 * [self.intermediate_size],
            bias=True,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        """Apply up-projection, clamped gegelu, and down-projection to ``x``."""
        gate_up, _ = self.up_proj(x)
        # Pass the configured limit. Previously it was stored but never used,
        # so the activation clamping requested by the config was skipped.
        x = gegelu(gate_up, limit=self.gegelu_limit)
        x, _ = self.down_proj(x)
        return x


class Phi3SmallSelfAttention(nn.Module):
    """Grouped-query self-attention for Phi-3-small.

    Most layers use block-sparse attention. When
    ``config.dense_attention_every_n_layers`` is set, every n-th layer
    (1-based) uses dense attention instead.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.sparse_block_size = config.blocksparse_block_size
        self.homo_heads = config.blocksparse_homo_head_pattern
        self.local_blocks = config.blocksparse_num_local_blocks
        self.vert_stride = config.blocksparse_vert_stride

        # The sparsity block size must match the Triton kernel's block size.
        assert (config.blocksparse_block_size ==
                config.blocksparse_triton_kernel_block_size)

        self.hidden_size = config.hidden_size
        # Number of query heads before tensor parallelism.
        self.num_heads = config.num_attention_heads

        self.head_dim = self.hidden_size // self.num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        # Number of key/value heads before tensor parallelism.
        self.num_key_value_heads = config.num_key_value_heads
        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
        if self.tp_size > 1:
            assert self.num_key_value_heads % self.tp_size == 0
        # NOTE: attribute keeps its historical spelling ("partion").
        self.num_kv_heads_per_partion = max(
            1, self.num_key_value_heads // self.tp_size)
        self.num_heads_per_partition = self.num_heads // self.tp_size

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_embedding_base = config.rope_embedding_base
        self.rope_position_scale = config.rope_position_scale
        self.is_causal = True

        # With mup scaling the softmax scale is mup_attn_multiplier / head_dim;
        # otherwise it is the usual 1 / sqrt(head_dim).
        if config.mup_use_scaling:
            norm_factor = self.head_dim / config.mup_attn_multiplier
        else:
            norm_factor = math.sqrt(self.head_dim)
        self.scale = 1 / norm_factor

        self.query_key_value = HeadMajorQKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.dense = RowParallelLinear(self.hidden_size,
                                       self.hidden_size,
                                       bias=True,
                                       quant_config=quant_config)

        rope_scaling = getattr(self.config, "rope_scaling", None)
        if rope_scaling is not None:
            # Build a new dict instead of mutating the shared HF config once
            # per layer. List values become tuples (presumably so get_rope
            # can hash the settings -- confirm).
            rope_scaling = {
                key: tuple(value) if isinstance(value, list) else value
                for key, value in rope_scaling.items()
            }
            rope_scaling.setdefault("factor", self.rope_position_scale)
        else:
            rope_scaling = {
                "rope_type": "linear",
                "factor": self.rope_position_scale,
            }

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_embedding_base,
            rope_scaling=rope_scaling,
        )

        # blocksparse params
        self.blocksparse_block_size = config.blocksparse_block_size
        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
        self.blocksparse_vert_stride = config.blocksparse_vert_stride

        # Every `dense_attention_every_n_layers`-th layer (1-based) is dense.
        use_dense_attn = (getattr(self.config,
                                  "dense_attention_every_n_layers", None)
                          and (self.layer_idx + 1) %
                          self.config.dense_attention_every_n_layers == 0)

        # None tells Attention to run dense attention for this layer.
        bs_params = None
        if not use_dense_attn:
            bs_params = {
                'max_seqlen': self.max_position_embeddings,
                'num_heads': self.num_heads_per_partition,
                "num_kv_heads": self.num_kv_heads_per_partion,
                "block_size": self.sparse_block_size,
                "local_blocks": self.local_blocks,
                "vert_stride": self.vert_stride,
                "homo_head": self.homo_heads
            }

        self.attn = Attention(self.num_heads_per_partition,
                              self.head_dim,
                              self.scale,
                              num_kv_heads=self.num_kv_heads_per_partion,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              blocksparse_params=bs_params,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention on ``hidden_states``; return the projected output."""
        qkv, _ = self.query_key_value(hidden_states)

        # The fused projection is grouped per KV head: each group holds
        # num_q_per_kv query heads, then one key head, then one value head.
        qkv = qkv.view(qkv.shape[:-1] +
                       (-1, (self.num_q_per_kv + 2), self.head_dim))
        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)

        # NOTE: the rotary embedding currently expects flattened 2D q/k,
        # so heads are merged back into the last dimension here.
        # TODO: allow 3D QK for rotary forward
        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
        output, _ = self.dense(attn_output)

        return output


class Phi3SmallDecoderLayer(nn.Module):
    """One Phi-3-small transformer block.

    Structure: LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Phi3SmallSelfAttention(
            config,
            layer_idx,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Phi3SmallMLP(config, quant_config)

        eps = config.layer_norm_epsilon
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections."""
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class Phi3SmallModel(nn.Module):
    """Phi-3-small decoder stack with pipeline-parallel support.

    It contains the token embedding, this rank's slice of decoder layers, and
    a final LayerNorm. The first pipeline rank embeds the tokens. Later
    ranks consume ``intermediate_tensors``. Every rank except the last
    returns ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hf_config.hidden_size)
        self.mup_embedding_multiplier = hf_config.mup_embedding_multiplier

        # make_layers invokes this factory with the keyword `prefix`, so the
        # parameter name must stay `prefix`. The layer index is parsed from
        # the last dot-separated component of that prefix.
        def build_layer(prefix: str) -> "Phi3SmallDecoderLayer":
            layer_idx = int(prefix.split('.')[-1])
            return Phi3SmallDecoderLayer(hf_config,
                                         layer_idx,
                                         cache_config,
                                         quant_config,
                                         prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        self.final_layernorm = nn.LayerNorm(hf_config.hidden_size,
                                            eps=hf_config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers over the input."""
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            multiplier = self.mup_embedding_multiplier
            # Only a positive multiplier rescales the embeddings.
            if multiplier is not None and multiplier > 0.0:
                hidden_states = hidden_states * multiplier

        # kv_caches is indexed relative to this stage's first layer.
        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](
                positions,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if get_pp_group().is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


371
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
    """Phi-3-small causal language model for vLLM.

    Wraps ``Phi3SmallModel`` with an LM head, which is optionally tied to the
    input embedding. It fills logits of ``config.dummy_token_indices`` (if
    configured) with ``-inf`` and divides logits by ``mup_width_multiplier``
    before sampling.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Phi3SmallModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))

        self.vocab_size = config.vocab_size
        self.mup_width_multiplier = config.mup_width_multiplier
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # The output projection shares the input embedding's weight.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        # Token ids that exist in the tiktoken vocabulary but are never used.
        # Their logits are forced to -inf in compute_logits.
        if hasattr(config, 'dummy_token_indices'):
            device = self.lm_head.weight.device
            self.register_buffer('dummy_token_indices',
                                 torch.LongTensor(
                                     config.dummy_token_indices).to(device),
                                 persistent=False)
        else:
            self.dummy_token_indices = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits and mask out dummy token ids.

        Returns whatever the logits processor returns (possibly ``None``).
        Masking happens in place along the last dimension.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        if self.dummy_token_indices is not None and logits is not None:
            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
        return logits

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns final hidden states on the last pipeline rank and
        ``IntermediateTensors`` on the other ranks.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits divided by mup_width_multiplier."""
        next_tokens = self.sampler(logits / self.mup_width_multiplier,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into matching parameters.

        The following checkpoint entries are skipped:
        - rotary ``inv_freq`` buffers;
        - biases with no matching parameter;
        - parameters that live on another pipeline stage;
        - ``lm_head.weight`` when embeddings are tied, because it shares the
          embedding weight loaded under the embedding's own name.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params