# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Qwen2MoeConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Module-level logger (used by Qwen2MoeModel.load_weights for kv_scale warnings).
logger = init_logger(__name__)


class Qwen2MoeMLP(nn.Module):
    """Gated feed-forward block: gate_up_proj -> SiluAndMul -> down_proj.

    Used as the dense MLP of non-MoE layers and as the shared expert of
    ``Qwen2MoeSparseMoeBlock``.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output size of each of the two merged
            (gate and up) projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        reduce_results: Forwarded as ``reduce_results`` to the down
            projection (the MoE block passes ``False`` for its shared expert).
        expert_gate: Optional module whose output, passed through a sigmoid,
            scales the block's output.
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        expert_gate: Optional[torch.nn.Linear] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections fused into one column-parallel linear.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()
        self.expert_gate = expert_gate

    def forward(self, x):
        """Run the gated MLP on ``x``.

        When an ``expert_gate`` was supplied, the result is multiplied by
        ``sigmoid(expert_gate(x))``, computed from the block's *input*.
        """
        gate_up, _ = self.gate_up_proj(x)
        out = self.act_fn(gate_up)
        out, _ = self.down_proj(out)

        if self.expert_gate is not None:
            out = F.sigmoid(self.expert_gate(x)) * out

        return out


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward block for Qwen2MoE.

    A replicated (unquantized) linear router produces per-expert logits that
    ``SharedFusedMoE`` uses to dispatch each token to ``num_experts_per_tok``
    of ``num_experts`` routed experts. When
    ``config.shared_expert_intermediate_size > 0``, a sigmoid-gated
    ``Qwen2MoeMLP`` is attached as the shared expert, and the two outputs
    returned by the fused layer are summed.

    Args:
        config: HF Qwen2MoE config.
        quant_config: Optional quantization config for the experts (the
            router gate is always built with ``quant_config=None``).
        prefix: Parameter-name prefix for the submodules.

    Raises:
        ValueError: If the tensor-parallel world size exceeds
            ``config.num_experts``.
    """

    def __init__(
        self,
        config: Qwen2MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # NOTE: created unconditionally, even when there is no shared expert,
        # so this module always owns a ``shared_expert_gate`` parameter.
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # The TP reduction is done once in forward(), on the sum.
                reduce_results=False,
                expert_gate=self.shared_expert_gate,
                prefix=f"{prefix}.shared_expert",
            )
        else:
            self.shared_expert = None

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_expert,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block; the result is reshaped back to the input shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if self.shared_expert is not None:
            # With a shared expert attached, the fused layer's result is
            # indexed as a (shared, routed) pair; combine them by addition.
            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
        if self.tp_size > 1:
            # Both expert paths were built with reduce_results=False, so the
            # (possible) tensor-parallel all-reduce is applied here.
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class Qwen2MoeAttention(nn.Module):
    """Self-attention with a fused QKV projection and rotary embeddings.

    Query heads are partitioned evenly across tensor-parallel ranks. KV heads
    are partitioned when there are at least as many of them as ranks;
    otherwise each rank keeps a single KV head (replicated across ranks).

    Args:
        hidden_size: Model hidden size; ``head_dim`` is derived as
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads across all ranks.
        num_kv_heads: Total number of key/value heads across all ranks.
        rope_theta: Rotary embedding base passed to ``get_rope``.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: Optional KV-cache config for the attention layer.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix for the submodules.
        dual_chunk_attention_config: Forwarded to ``get_rope``; when truthy it
            is also passed to ``Attention`` along with the layer index parsed
            from ``prefix``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # Extra kwargs are only passed when dual-chunk attention is on.
            **(
                {
                    "layer_idx": extract_layer_index(prefix),
                    "dual_chunk_attention_config": dual_chunk_attention_config,
                }
                if dual_chunk_attention_config
                else {}
            ),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2MoeDecoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP/MoE block.

    The feed-forward part is a ``Qwen2MoeSparseMoeBlock`` when the layer index
    is not in ``config.mlp_only_layers``, ``config.num_experts > 0`` and
    ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``;
    otherwise it is a dense ``Qwen2MoeMLP``.

    Args:
        config: HF Qwen2MoE config.
        cache_config: Optional KV-cache config for the attention layer.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix; the layer index is parsed from it.
    """

    def __init__(
        self,
        config: Qwen2MoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        dual_chunk_attention_config = getattr(
            config, "dual_chunk_attention_config", None
        )
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        if (layer_idx not in mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(
                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            )
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Args:
            positions: Token positions for rotary embeddings.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual from the previous layer, or ``None``
                for the first layer.

        Returns:
            ``(hidden_states, residual)``; the residual add is left to the
            next layer's (or the final) norm call.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns a (normed, new_residual) pair.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2MoE decoder stack: embeddings, decoder layers and final norm.

    Supports pipeline parallelism: only layers in
    ``[start_layer, end_layer)`` run on this rank, and non-last ranks return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank starts from ``inputs_embeds`` (if given) or the
        embeddings of ``input_ids``; later ranks resume from
        ``intermediate_tensors``. Non-last ranks return the hidden states and
        residual as ``IntermediateTensors``; the last rank returns the
        final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed experts."""
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each tensor is tried, in order, against:

        1. the stacked projections (``q/k/v_proj`` -> ``qkv_proj``,
           ``gate/up_proj`` -> ``gate_up_proj``), excluding routed experts;
        2. the routed-expert mapping from ``get_expert_mapping``;
        3. a direct name match (FP8 ``kv_scale`` names are remapped to
           ``.attn.kv_scale``), using ``default_weight_loader`` when the
           parameter has no ``weight_loader``.

        Tensors for layers missing on this pipeline rank, and extra GPTQ bias
        tensors with no matching parameter, are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        def is_extra_bias(candidate: str) -> bool:
            # Skip loading extra bias for GPTQ models: a bias-named tensor
            # with no matching parameter in this model.
            return (
                candidate.endswith(".bias") or candidate.endswith("_bias")
            ) and candidate not in params_dict

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                # Keep the rewritten name separate until it is actually used:
                # if this candidate is rejected, ``name`` must stay intact so
                # later mappings and the fallback paths do not see a
                # half-rewritten name (e.g. "gate_up_proj" re-matching
                # "up_proj" and turning into "gate_gate_up_proj").
                stacked_name = name.replace(weight_name, param_name)
                if is_extra_bias(stacked_name):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(stacked_name, self):
                    continue
                if stacked_name not in params_dict:
                    continue
                name = stacked_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    # Same rule as above: do not mutate ``name`` for a
                    # candidate that may still be rejected.
                    expert_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(expert_name, self):
                        continue
                    if is_extra_bias(expert_name):
                        continue
                    name = expert_name
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # `continue` here advances the outer `for name, ...` loop.
                    if is_extra_bias(name):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  #  noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Qwen2MoE decoder stack with a (optionally tied) language-model head."""

    fall_back_to_pt_during_load = False
    # Fused module name -> checkpoint sub-module names it is built from.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen2MoeMLP exists: as the
        # shared expert, or as a dense MLP layer. The dense-layer cases mirror
        # the layer-selection rule in Qwen2MoeDecoderLayer.
        if (
            getattr(config, "mlp_only_layers", [])
            or config.shared_expert_intermediate_size > 0
            or config.num_experts == 0
            or config.decoder_sparse_step > 1
        ):
            # NOTE: this writes into the class-level dict in place (not a
            # per-instance copy); presumably intentional so holders of a
            # reference to the class mapping see the entry — verify before
            # changing. The value must be a list of names, not a tuple
            # wrapping a list.
            self.packed_modules_mapping["gate_up_proj"] = [
                "gate_proj",
                "up_proj",
            ]

        self.model = Qwen2MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner model; logits are produced separately by compute_logits."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Apply the LM head to ``hidden_states`` via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via AutoWeightsLoader; returns loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the routed-expert parameter mapping of the inner model."""
        return self.model.get_expert_mapping()