qwen2_moe.py 22.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
27

28
from collections.abc import Iterable
29
from itertools import islice
30
from typing import Any
31
32
33
34

import torch
import torch.nn.functional as F
from torch import nn
35
from transformers import Qwen2MoeConfig
36

37
from vllm.compilation.decorators import support_torch_compile
38
from vllm.config import CacheConfig, VllmConfig
39
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
40
from vllm.logger import init_logger
41
from vllm.model_executor.layers.activation import SiluAndMul
42
from vllm.model_executor.layers.attention import Attention
43
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
44
from vllm.model_executor.layers.layernorm import RMSNorm
45
46
47
48
49
50
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.quantization import QuantizationConfig
53
54
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
56
57
    ParallelLMHead,
    VocabParallelEmbedding,
)
58
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
59
from vllm.sequence import IntermediateTensors
60

61
from .interfaces import SupportsLoRA, SupportsPP
62
63
64
65
66
67
68
69
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
70

71
72
logger = init_logger(__name__)

73
74
75
76
77
78
79

class Qwen2MoeMLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection -> SiluAndMul -> down.

    Used as the dense MLP of non-MoE decoder layers and as the shared expert
    of ``Qwen2MoeSparseMoeBlock``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        reduce_results: Forwarded to ``down_proj`` (``RowParallelLinear``).
        expert_gate: Optional gating layer applied to the block input. Its
            call result is indexed with ``[0]``, so it must return a tuple
            like the vLLM linear layers do (e.g. ``ReplicatedLinear``), not a
            bare tensor as ``torch.nn.Linear`` would.
        is_sequence_parallel: Forwarded as ``disable_tp`` to both linears.
        prefix: Parameter-name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        expert_gate: ReplicatedLinear | None = None,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x):
        """Apply the gated MLP to ``x``, optionally scaled by a sigmoid gate."""
        gate_up, _ = self.gate_up_proj(x)
        out = self.act_fn(gate_up)
        out, _ = self.down_proj(out)

        if self.expert_gate is not None:
            # Like the other linears here, the gate returns a tuple; element 0
            # is the gate logit, squashed to (0, 1) to scale the output.
            out = F.sigmoid(self.expert_gate(x)[0]) * out

        return out
120
121
122
123
124


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts FFN with an optional sigmoid-gated shared expert.

    Routed experts and (when ``shared_expert_intermediate_size > 0``) the
    shared expert are both executed through one ``SharedFusedMoE`` layer.
    Their results are summed and, under tensor parallelism, reduced once.

    Raises:
        ValueError: If the tensor-parallel world size exceeds the number of
            experts.
    """

    def __init__(
        self,
        config: Qwen2MoeConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        num_experts = config.num_experts
        hidden_size = config.hidden_size

        if self.tp_size > num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Router: one logit per expert, kept unquantized and replicated.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        # Scalar gate for the shared expert, also unquantized and replicated.
        self.shared_expert_gate = ReplicatedLinear(
            hidden_size,
            1,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.shared_expert_gate",
        )

        # reduce_results=False: reduction is deferred to forward(), where it is
        # applied once to the combined shared + routed output.
        has_shared_expert = config.shared_expert_intermediate_size > 0
        self.shared_expert = (
            Qwen2MoeMLP(
                hidden_size=hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
            if has_shared_expert
            else None
        )

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            num_experts=num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a tensor of the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        combined = self.experts(hidden_states=tokens, router_logits=router_logits)

        # With a shared expert configured, the fused layer yields a pair of
        # partial outputs that must be added together.
        if self.shared_expert is not None:
            combined = combined[0] + combined[1]

        if self.tp_size > 1:
            combined = self.experts.maybe_all_reduce_tensor_model_parallel(combined)

        return combined.view(input_shape)
198
199
200
201
202
203
204
205


class Qwen2MoeAttention(nn.Module):
    """Rotary self-attention with grouped KV heads, sharded over TP ranks.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as ranks; otherwise each rank keeps
    a single (replicated) KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        dual_chunk_attention_config: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads: must divide evenly across ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads: partitioned if num_kv_heads >= tp_size, else replicated.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads >= tp_size:
            assert num_kv_heads % tp_size == 0
        else:
            assert tp_size % num_kv_heads == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        self.head_dim = hidden_size // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            num_heads,
            num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Dual-chunk attention needs the layer index and its config passed
        # through; otherwise the Attention layer gets no extra kwargs.
        dca_kwargs: dict[str, Any] = {}
        if dual_chunk_attention_config:
            dca_kwargs["layer_idx"] = extract_layer_index(prefix)
            dca_kwargs["dual_chunk_attention_config"] = dual_chunk_attention_config

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **dca_kwargs,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected


class Qwen2MoeDecoderLayer(nn.Module):
    """One transformer block: pre-norm self-attention, then a pre-norm FFN.

    The FFN is a ``Qwen2MoeSparseMoeBlock`` when the layer is not listed in
    ``config.mlp_only_layers``, experts are configured, and the layer index
    falls on a ``decoder_sparse_step`` boundary; otherwise it is a dense
    ``Qwen2MoeMLP``.
    """

    def __init__(
        self,
        config: Qwen2MoeConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config. A missing or None value means
        # "no dense-only layers" (None would otherwise break the `in` test).
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        use_sparse_moe = (
            layer_idx not in mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if use_sparse_moe:
            self.mlp = Qwen2MoeSparseMoeBlock(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and FFN sub-blocks.

        Args:
            positions: Token positions, forwarded to the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream; ``None`` at the start of the
                stack, in which case ``hidden_states`` seeds it.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


363
@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: only this rank's slice of layers
    (``start_layer`` to ``end_layer``) is executed, and every rank except the
    last returns ``hidden_states``/``residual`` as ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's layers.

        The first PP rank starts from ``inputs_embeds`` (if given) or the
        embedding of ``input_ids``; later ranks resume from
        ``intermediate_tensors``. The last rank applies the final norm and
        returns a tensor; other ranks return ``IntermediateTensors``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples
        mapping per-expert checkpoint names onto the fused expert parameters."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair goes down one of three paths:

        1. Stacked projections outside ``mlp.experts`` (q/k/v into
           ``qkv_proj``, gate/up into ``gate_up_proj``), loaded per shard.
        2. Per-expert weights, renamed via :meth:`get_expert_mapping` and
           loaded with their expert id.
        3. Everything else, loaded directly, with FP8 ``kv_scale`` renaming
           and a GGUF fix-up for 1-D ``shared_expert_gate`` weights.

        Weights of layers on other pipeline ranks, and extra bias tensors with
        no matching parameter (e.g. from GPTQ checkpoints), are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                # Rename into a candidate and commit it only on success, so a
                # skip below cannot leak the rewritten name into later
                # mappings (e.g. "qkv_proj" itself contains "v_proj") or into
                # the fallback paths.
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    mapped_name.endswith(".bias") or mapped_name.endswith("_bias")
                ) and mapped_name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue

                name = mapped_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Same candidate-name discipline as the stacked path.
                    mapped_name = name.replace(weight_name, param_name)

                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        mapped_name.endswith(".bias") or mapped_name.endswith("_bias")
                    ) and mapped_name not in params_dict:
                        continue
                    name = mapped_name
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  #  noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        else:
                            name = remapped_kv_scale_name
                    # GGUF: make sure that shared_expert_gate is a 2D tensor.
                    if (
                        "mlp.shared_expert_gate" in name
                        and len(loaded_weight.shape) == 1
                    ):
                        loaded_weight = loaded_weight[None, :]
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
539
540


541
class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Qwen2-MoE causal LM: the ``Qwen2MoeModel`` decoder plus an LM head.

    Supports pipeline parallelism and LoRA. The output head shares the input
    embedding weights when ``config.tie_word_embeddings`` is set.
    """

    fall_back_to_pt_during_load = False
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_config

        # Only perform the following mapping when Qwen2MoeMLP exists.
        # NOTE(review): item assignment here writes into the class-level dict
        # (shared by every instance), not an instance copy. Presumably this is
        # deliberate if other code holds a reference to that dict -- confirm
        # before replacing it with a per-instance copy.
        has_dense_mlp = (
            getattr(hf_config, "mlp_only_layers", [])
            or hf_config.shared_expert_intermediate_size > 0
        )
        if has_dense_mlp:
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

        self.model = Qwen2MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the decoder model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder; logits are produced separately by compute_logits."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via AutoWeightsLoader; return loaded names."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the decoder model's expert parameter mapping."""
        return self.model.get_expert_mapping()