step3_text.py 19.4 KB
Newer Older
Song's avatar
Song committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
4

Song's avatar
Song committed
5
from collections.abc import Iterable
6
from itertools import islice
7
from typing import Any
Song's avatar
Song committed
8
9
10
11
12
13

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
14
15
16
17
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
)
Song's avatar
Song committed
18
19
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.attention import Attention
Song's avatar
Song committed
21
22
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
23
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Song's avatar
Song committed
29
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
from vllm.model_executor.layers.quantization import QuantizationConfig
Song's avatar
Song committed
31
32
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
34
35
    ParallelLMHead,
    VocabParallelEmbedding,
)
Song's avatar
Song committed
36
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
38
from vllm.transformers_utils.configs.step3_vl import Step3TextConfig
Song's avatar
Song committed
39
40

from .interfaces import SupportsPP
41
42
43
44
45
46
47
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
Song's avatar
Song committed
48
49
50
51
52

logger = init_logger(__name__)


class FusedMoEBlock(nn.Module):
    """Routed mixture-of-experts block: a linear router plus fused experts."""

    def __init__(
        self,
        config: Step3TextConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the router gate and the fused expert layer.

        Args:
            config: Text-model config. This block reads ``moe_num_experts``,
                ``moe_top_k``, ``hidden_size``, ``moe_intermediate_size`` and
                ``norm_expert_weight`` from it.
            quant_config: Quantization settings forwarded to the experts. The
                router gate is always built with ``quant_config=None``.
            prefix: Dotted module path used to name the sub-layers.

        Raises:
            ValueError: If the tensor-parallel world size exceeds the number
                of experts.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}."
            )

        self.experts = FusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.moe_num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and restore the input shape.

        The input is flattened to ``(-1, hidden_dim)`` before routing, and the
        expert output is viewed back to the original shape on return.
        """
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # The gate returns a pair; only the logits are needed here.
        router_logits, _ = self.gate(hidden_states)

        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        return final_hidden_states.view(orig_shape)


class Step3TextMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul, down."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Width of each of the two merged gate/up halves.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization settings passed to both projections.
            prefix: Dotted module path used to name the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before any projection weights are
        # allocated, so a bad config fails immediately and cheaply.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the activation, then the down projection."""
        gate_up, _ = self.gate_up_proj(hidden_states)
        intermediate_act = self.act_fn(gate_up)
        output, _ = self.down_proj(intermediate_act)
        return output


class Step3TextAttention(nn.Module):
    """Attention with a single shared KV head and a two-stage query projection.

    The fused ``qkv_proj`` emits a query of width ``q_size`` together with one
    key and one value head. The query is RMS-normalised and then expanded by
    ``wq`` before rotary embedding and attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_parameters: dict[str, Any],
        share_q_dim: int | None = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the projection, normalisation, rotary and attention layers.

        Raises:
            ValueError: If ``num_kv_heads`` is not 1.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are divided evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        heads_per_rank, leftover = divmod(self.total_num_heads, world_size)
        assert leftover == 0
        self.num_heads = heads_per_rank

        if num_kv_heads != 1:
            raise ValueError(
                f"Step3TextAttention num_kv_heads must be 1, but got {num_kv_heads}."
            )
        self.num_kv_heads = num_kv_heads

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # A falsy share_q_dim (None or 0) falls back to one head's width.
        self.q_size = share_q_dim or self.head_dim

        all_heads_width = self.total_num_heads * self.head_dim

        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            self.q_size + 2 * self.kv_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            all_heads_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            all_heads_width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embedding,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.head_dim**-0.5,
            self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q_latent, k, v = fused.split(split_sizes, dim=-1)
        # Normalise the query, then expand it to the per-head layout.
        q = self.wq(self.inter_norm(q_latent))[0]
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected


class Step3TextDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense MLP or MoE block."""

    def __init__(
        self,
        config: Step3TextConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the feed-forward variant for this layer, and norms.

        The layer index is parsed from ``prefix`` (the number that follows
        ``"layers."``) and decides whether this layer uses the MoE path.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=config.rms_norm_eps,
            max_position_embedding=config.max_position_embedding,
            head_dim=config.head_dim,
            share_q_dim=config.share_q_dim,
            rope_parameters=config.rope_parameters,
            prefix=f"{prefix}.self_attn",
        )

        after_layers = prefix.split("layers.")[1]
        layer_idx = int(after_layers.split(".")[0])

        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            # Default: layer 0 is dense, every later layer is MoE.
            moe_layers_idx = list(range(1, config.num_hidden_layers))
        else:
            # Comma-separated list of MoE layer indices.
            moe_layers_idx = [
                int(token) for token in moe_layers_enum.strip().split(",")
            ]

        if layer_idx not in moe_layers_idx:
            self.mlp = Step3TextMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
            self.use_moe = False
        else:
            self.moe = FusedMoEBlock(
                config=config, quant_config=quant_config, prefix=f"{prefix}.moe"
            )
            self.share_expert = Step3TextMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert",
            )
            self.use_moe = True

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` become the
        residual; otherwise the input norm is called with both tensors and
        returns the updated pair.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
        )

        ffn_in, residual = self.post_attention_layernorm(attn_out, residual)

        if not self.use_moe:
            return self.mlp(ffn_in), residual

        # MoE layers add the shared expert's output to the routed output.
        share_output = self.share_expert(ffn_in)
        moe_output = self.moe(ffn_in)
        return share_output + moe_output, residual


@support_torch_compile
class Step3TextModel(nn.Module):
    """Decoder stack: token embedding, decoder layers and a final RMS norm.

    Under pipeline parallelism only the layers owned by this rank are run;
    the embedding and final norm are replaced by ``PPMissingLayer`` on ranks
    that do not need them.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the embedding, this rank's layers and the final norm."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.config = config

        # The embedding is built on the first rank, and also on the last rank
        # when the config ties word embeddings.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3TextDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # forward() reads and writes both "hidden_states" and "residual"
        # between pipeline stages, so the empty-tensor factory must provide
        # both keys.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor passed to every decoder layer.
            intermediate_tensors: Output of the previous pipeline stage;
                required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that take
                precedence over ``input_ids`` on the first rank.

        Returns:
            The normalised hidden states on the last rank, otherwise an
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``
            for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Step3TextForCausalLM(nn.Module, SupportsPP):
    """Step3 text decoder with a language-model head and checkpoint loading."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder and, on the last pipeline rank, the LM head."""
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            self.logits_processor = LogitsProcessor(config.vocab_size)
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ):
        """Delegate to the decoder and return whatever it returns."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried against four routes, in order:

        1. dense ``gate_proj`` / ``up_proj`` into the merged ``gate_up_proj``;
        2. MoE expert weights, loaded one expert at a time;
        3. ``q_proj`` / ``k_proj`` / ``v_proj`` into a slice of the fused
           ``qkv_proj``;
        4. everything else, loaded directly under its own name.

        Parameters that belong to another pipeline rank are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The names of the parameters that received data.
        """
        head_dim = self.config.head_dim
        # Same fallback as Step3TextAttention: a falsy share_q_dim means the
        # query slice is one head wide.
        q_dim = self.config.share_q_dim or head_dim
        qkv_total = q_dim + head_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_offset, end_offset)
            # Offsets are in units of qkv_total and are rescaled to the
            # parameter's real output size at load time.
            (".qkv_proj", ".q_proj", 0, q_dim),
            (".qkv_proj", ".k_proj", q_dim, q_dim + head_dim),
            (".qkv_proj", ".v_proj", q_dim + head_dim, qkv_total),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Some parameter names carry an extra ".base_layer." component; when
        # they do, the expert parameter names need it as well.
        base_layer = (
            "base_layer." if any(".base_layer." in name for name in params_dict) else ""
        )

        expert_params_mapping = [
            (f".moe.experts.{base_layer}w13_weight", ".moe.gate_proj.weight", "w1"),
            (f".moe.experts.{base_layer}w13_weight", ".moe.up_proj.weight", "w3"),
            (f".moe.experts.{base_layer}w2_weight", ".moe.down_proj.weight", "w2"),
        ]

        # Expert weight names also contain ".gate_proj" / ".up_proj", so they
        # must be kept out of the dense stacked-parameter route.
        disable_moe_stacked_params = [data[1] for data in expert_params_mapping]

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(
                    disable_moe_stacked_param in name
                    for disable_moe_stacked_param in disable_moe_stacked_params
                ):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith((".bias", "_bias")) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint tensor is indexed by expert on dim 0.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(
                            param,
                            loaded_weight_expert,
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(name)
                    break
                else:
                    for (
                        param_name,
                        weight_name,
                        start_offset,
                        end_offset,
                    ) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        # Exact integer scaling; truncating a float fraction
                        # can land one element short of the real boundary.
                        begin_idx = start_offset * dim // qkv_total
                        stop_idx = end_offset * dim // qkv_total
                        param_slice = param.narrow(
                            param.output_dim, begin_idx, stop_idx - begin_idx
                        )
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params