# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""

from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)


class Qwen2MoeMLP(nn.Module):
    """Gated SiLU feed-forward block built from tensor-parallel linears.

    Computes ``down_proj(SiluAndMul(gate_up_proj(x)))``. It is used as the
    dense MLP of non-MoE decoder layers and as the shared expert inside
    ``Qwen2MoeSparseMoeBlock``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        """Build the fused gate/up projection and the down projection.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            reduce_results: Forwarded to ``RowParallelLinear`` for the down
                projection.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before allocating any weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections share one column-parallel matmul; the
        # output holds both halves, which SiluAndMul consumes.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward block.

    A replicated router (``gate``) produces per-token expert logits that
    ``FusedMoE`` uses to dispatch each token to ``num_experts_per_tok`` of
    ``num_experts`` experts. When ``shared_expert_intermediate_size > 0``, a
    dense shared expert is added, scaled by a sigmoid gate.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build router, routed experts and (optionally) the shared expert.

        Args:
            config: HF config providing the expert counts and sizes.
            quant_config: Optional quantization config for the experts.
            prefix: Module name prefix; experts are registered under
                ``f"{prefix}.experts"``.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        # reduce_results=False: the cross-rank reduction is done once in
        # forward(), after the shared-expert output has been added.
        self.experts = FusedMoE(num_experts=config.num_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=False,
                                renormalize=config.norm_topk_prob,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts")

        # Router is replicated on every rank and never quantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None)
        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=self.experts.must_reduce_shared_expert_outputs(
                ),
            )
        else:
            self.shared_expert = None
        # One scalar per token that scales the shared-expert output; created
        # unconditionally but only used when a shared expert exists.
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine routed-expert and gated shared-expert outputs.

        The input is flattened to ``(-1, hidden_dim)`` for routing and the
        result is reshaped back to the original input shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = F.sigmoid(
                    self.shared_expert_gate(hidden_states)) * shared_output

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states)

        return final_hidden_states.view(orig_shape)


class Qwen2MoeAttention(nn.Module):
    """Self-attention with rotary position embeddings, sharded over TP ranks.

    Query heads are partitioned across tensor-parallel ranks. KV heads are
    either partitioned (``num_kv_heads >= tp_size``) or replicated
    (``num_kv_heads < tp_size``), so ``num_kv_heads`` may differ from
    ``num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
    ) -> None:
        """Build QKV/output projections, rotary embedding and attention op.

        Args:
            hidden_size: Model hidden width; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rotary scaling config passed to
                ``get_rope``.
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Optional quantization config for the projections
                and attention.
            prefix: Module name prefix; also used to derive the layer index.
            dual_chunk_attention_config: Optional config forwarded to both
                ``get_rope`` and ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.dual_chunk_attention_config = dual_chunk_attention_config

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
        # layer_idx and dual_chunk_attention_config are only passed to
        # Attention when dual-chunk attention is configured.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {})

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, rotate Q and K, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2MoeDecoderLayer(nn.Module):
    """One transformer layer: pre-norm self-attention followed by an MLP.

    The MLP is either a ``Qwen2MoeSparseMoeBlock`` or a dense
    ``Qwen2MoeMLP``, selected per layer from the config.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP (sparse or dense) and the two RMSNorms.

        Args:
            config: HF model config.
            cache_config: KV-cache config forwarded to attention.
            quant_config: Optional quantization config.
            prefix: Module name prefix; the layer index is extracted from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        dual_chunk_attention_config = getattr(config,
                                              "dual_chunk_attention_config",
                                              None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            dual_chunk_attention_config=dual_chunk_attention_config,
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        # Sparse MoE unless the layer is forced dense, the model has no
        # experts, or the 1-based layer number is not a multiple of
        # decoder_sparse_step.
        use_sparse_moe = (layer_idx not in mlp_only_layers
                          and config.num_experts > 0
                          and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_sparse_moe:
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP, threading the residual stream through.

        When ``residual`` is None, the incoming hidden states become the
        residual. Otherwise both tensors are passed to the RMSNorm, which
        returns the normalized hidden states and the updated residual.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: embeddings, decoder layers and final norm.

    Supports pipeline parallelism. Each rank runs only its slice of layers
    and exchanges ``hidden_states``/``residual`` through
    ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Engine config; the HF config, cache config and
                quantization config are read from it.
            prefix: Module name prefix; layers live under
                ``f"{prefix}.layers"``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # start_layer/end_layer bound the slice of self.layers that this
        # rank executes in forward().
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(config=config,
                                                cache_config=cache_config,
                                                quant_config=quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        else the embedding of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-last ranks; the final-normed hidden states on the last
            rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint name takes exactly one of three paths:

        * separate ``q/k/v_proj`` and ``gate/up_proj`` weights outside the
          routed experts are loaded as shards of the fused ``qkv_proj`` /
          ``gate_up_proj`` parameters;
        * routed-expert weights are renamed via
          ``FusedMoE.make_expert_params_mapping`` and loaded with their
          expert id and shard id;
        * everything else is loaded by name, with ``kv_scale`` tensors
          remapped to ``.attn.kv_scale``.

        The following weights are skipped:

        * extra bias tensors that have no matching parameter;
        * tensors for layers on other pipeline ranks;
        * stacked names that are absent from the model.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Routed experts are excluded from the stacked mapping: renaming
            # mlp.experts[0].gate_proj to gate_up_proj first would later be
            # turned into gate_gate_up_proj by the expert mapping.
            stacked = next(
                ((param_name, weight_name, shard_id)
                 for param_name, weight_name, shard_id in stacked_params_mapping
                 if weight_name in name and "mlp.experts" not in name), None)
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                # Each name is renamed at most once. Re-scanning the other
                # mappings with the renamed string would rewrite it again
                # (`qkv_proj` contains `v_proj`), so skips end here.
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                continue

            expert = next(
                ((param_name, weight_name, expert_id, shard_id)
                 for param_name, weight_name, expert_id, shard_id in
                 expert_params_mapping if weight_name in name), None)
            if expert is not None:
                param_name, weight_name, expert_id, shard_id = expert
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                loaded_params.add(name)
                continue

            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            # Remapping the name of FP8 kv-scale.
            if name.endswith("kv_scale"):
                remapped_kv_scale_name = name.replace(
                    ".kv_scale", ".attn.kv_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  #  noqa: E501
                        name,
                        remapped_kv_scale_name,
                    )
                    continue
                name = remapped_kv_scale_name
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen2-MoE causal LM: ``Qwen2MoeModel`` plus a (possibly tied) LM head."""

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Engine config; the HF and quantization configs are
                read from it.
            prefix: Module name prefix; the decoder is registered under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        # With tied embeddings the LM head shares the input embedding matrix.
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder stack."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are computed by ``compute_logits``."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader``; returns loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)