# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
26
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
27
28
29
30
31
32

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
37
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
38
                              tensor_model_parallel_all_reduce)
39
from vllm.logger import init_logger
40
from vllm.model_executor.layers.activation import SiluAndMul
41
from vllm.model_executor.layers.fused_moe import FusedMoE
42
from vllm.model_executor.layers.layernorm import RMSNorm
43
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
44
45
46
47
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
51
52
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
53
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
54
from vllm.model_executor.sampling_metadata import SamplingMetadata
55
from vllm.sequence import IntermediateTensors
56

57
from .interfaces import SupportsPP
58
59
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
60
61
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
62

63
64
logger = init_logger(__name__)

65
66
67
68
69
70
71
72

class Qwen2MoeMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul,
    then a down projection.

    In this file it is used both as the dense MLP of a decoder layer and as
    the shared expert inside ``Qwen2MoeSparseMoeBlock``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        """Build the projection layers.

        Args:
            hidden_size: Input and output feature size of the block.
            intermediate_size: Size of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            reduce_results: Forwarded to the down projection's
                ``reduce_results`` argument.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection weights are allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the activation, and the down projection."""
        # Both linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP with an optional gated shared expert.

    The routed experts and the shared expert are both built with
    ``reduce_results=False``; their outputs are summed and, when tensor
    parallelism is active, all-reduced once in ``forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the router, the fused experts and the shared expert.

        Args:
            config: HF config; reads ``num_experts``, ``num_experts_per_tok``,
                ``hidden_size``, ``moe_intermediate_size``, ``norm_topk_prob``,
                ``shared_expert_intermediate_size`` and ``hidden_act``.
            quant_config: Optional quantization config for the experts and
                the shared expert. The router gate is built with
                ``quant_config=None``.
            prefix: Module name prefix; ``".experts"`` is appended for the
                fused experts layer.

        Raises:
            ValueError: If the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.experts = FusedMoE(num_experts=config.num_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=False,
                                renormalize=config.norm_topk_prob,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts")

        # Router producing one logit per expert.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None)
        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )
        else:
            self.shared_expert = None
        # Scalar gate applied (through a sigmoid) to the shared expert output.
        # It is created even when there is no shared expert.
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared expert output.

        Returns a tensor with the same shape as ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = F.sigmoid(
                    self.shared_expert_gate(hidden_states)) * shared_output

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Single reduction for the summed routed + shared outputs.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)
164
165
166
167
168
169
170
171
172
173
174
175


class Qwen2MoeAttention(nn.Module):
    """Self-attention with a fused QKV projection and rotary embeddings.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    either partitioned (when there are at least as many as ranks) or
    replicated (when there are fewer).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            hidden_size: Model hidden size; the head dimension is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            cache_config: Optional cache config forwarded to ``Attention``.
            quant_config: Optional quantization config for the projections
                and the attention layer.
            prefix: Module name prefix; ``".attn"`` is appended for the
                attention layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2MoeDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense or sparse MLP.

    The layer index is extracted from ``prefix``. A sparse MoE block is used
    when the index is not listed in ``config.mlp_only_layers``,
    ``config.num_experts > 0`` and
    ``(layer_idx + 1) % config.decoder_sparse_step == 0``; otherwise a dense
    ``Qwen2MoeMLP`` is used.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the MLP variant for this layer, and both norms.

        Args:
            config: HF model config.
            cache_config: Optional cache config forwarded to the attention.
            quant_config: Optional quantization config forwarded to the
                attention and MLP sub-modules.
            prefix: Module name prefix of this layer; also the source of the
                layer index.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        use_sparse_mlp = (layer_idx not in mlp_only_layers
                          and config.num_experts > 0 and
                          (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_sparse_mlp:
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position tensor forwarded to the attention module.
            hidden_states: Input activations.
            residual: Residual carried from the previous layer, or ``None``
                when there is none yet; in that case the unnormalised input
                becomes the residual.

        Returns:
            The MLP output and the residual to carry forward. The residual is
            not added to the output here; the caller passes both on.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns an updated
            # (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


322
@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: token embedding, decoder layers, final norm.

    Only the layers in ``[start_layer, end_layer)`` returned by
    ``make_layers`` are executed in ``forward``; intermediate activations are
    exchanged as ``IntermediateTensors`` between pipeline-parallel ranks.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Module name prefix; ``".layers"`` is appended for the
                decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(config=config,
                                                cache_config=cache_config,
                                                quant_config=quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: ``hidden_states`` and ``residual`` from the
                previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            Normalised hidden states on the last pipeline rank, otherwise an
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is tried against the stacked-parameter mapping
        (q/k/v into ``qkv_proj``, gate/up into ``gate_up_proj``), then the
        per-expert mapping, and finally loaded under its own name.

        The checkpoint name is only replaced by the mapped parameter name
        once a load succeeds, so a skipped candidate does not leak a
        half-rewritten name into later mapping attempts.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        bias_suffixes = (".bias", "_bias")
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we map the name, otherwise
                # it would become mlp.experts[0].gate_up_proj, which
                # would then be mapped again in expert_params_mapping
                # to mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (mapped_name.endswith(bias_suffixes)
                        and mapped_name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                if mapped_name not in params_dict:
                    continue

                param = params_dict[mapped_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                # Commit the mapped name only after a successful load.
                name = mapped_name
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (mapped_name.endswith(bias_suffixes)
                            and mapped_name not in params_dict):
                        continue
                    param = params_dict[mapped_name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  mapped_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    name = mapped_name
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (name.endswith(bias_suffixes)
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541


class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen2-MoE causal language model: decoder stack, LM head, logits
    processor and sampler."""

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack and output head.

        Args:
            vllm_config: Source of the HF model config and quantization
                config; also passed through to ``Qwen2MoeModel``.
            prefix: Module name prefix; the decoder stack is created under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``"rotary_emb.inv_freq"`` is passed as a skip prefix. Returns the
        loader's set of loaded parameter names.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["rotary_emb.inv_freq"]),
        )
        return loader.load_weights(weights)