qwen2_moe.py 21.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
25
26
27
28
29
30
31

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
35
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
36
                              tensor_model_parallel_all_reduce)
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.fused_moe import FusedMoE
39
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
41
42
43
44
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
47
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
48
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.sequence import IntermediateTensors
53
from vllm.utils import print_warning_once
54

55
56
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

60
61
62
63
64
65
66
67

class Qwen2MoeMLP(nn.Module):
    """Dense gated feed-forward network.

    Used both as the shared expert of the sparse MoE block and as the MLP of
    dense (non-MoE) decoder layers. Data flow: ``gate_up_proj`` (a merged
    column-parallel projection with two ``intermediate_size`` outputs) ->
    ``SiluAndMul`` -> ``down_proj`` (row-parallel, back to ``hidden_size``).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        """Create the projections and the activation.

        Args:
            hidden_size: Input and output width.
            intermediate_size: Width of each of the two merged gate/up outputs.
            hidden_act: Activation name; only "silu" is accepted.
            quant_config: Optional quantization config for both projections.
            reduce_results: Forwarded to the row-parallel ``down_proj``.

        Raises:
            ValueError: If ``hidden_act`` is not "silu".
        """
        super().__init__()
        merged_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply the gate/up projection, SiluAndMul, then the down projection."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts MLP with an optional gated shared expert.

    Tokens are routed by ``gate`` (a replicated, unquantized linear layer) to
    ``config.num_experts_per_tok`` of ``config.num_experts`` routed experts
    held in a ``FusedMoE``. When ``config.shared_expert_intermediate_size > 0``
    a dense shared expert is also applied to every token and scaled by
    ``sigmoid(shared_expert_gate(x))``. Sub-modules are built with
    ``reduce_results=False``; one all-reduce is performed at the end of
    ``forward`` when the tensor-parallel world size is greater than 1.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build routed experts, router and (optional) shared expert.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        num_experts = config.num_experts
        if self.tp_size > num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
        )

        # Router: one logit per routed expert; never quantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
        )

        shared_width = config.shared_expert_intermediate_size
        if shared_width <= 0:
            self.shared_expert = None
        else:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_width,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

        # Scalar per-token gate for the shared expert. It is created even when
        # there is no shared expert, in which case forward() never uses it.
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Combine routed-expert and shared-expert outputs.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size. It
                is viewed as 2D ``(-1, hidden)`` for the MoE computation, and
                the result is viewed back to the input shape.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        # NOTE(review): the shared expert is evaluated before self.experts is
        # called; keep this order unless FusedMoE is confirmed not to reuse
        # its input buffer in place.
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(tokens)
            if self.shared_expert_gate is not None:
                shared_scale = torch.sigmoid(self.shared_expert_gate(tokens))
                shared_output = shared_scale * shared_output

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        combined = self.experts(hidden_states=tokens,
                                router_logits=router_logits)
        if shared_output is not None:
            combined = combined + shared_output

        # Sub-modules skip their own reduction, so reduce once here.
        if self.tp_size > 1:
            combined = tensor_model_parallel_all_reduce(combined)

        return combined.view(input_shape)
157
158
159
160
161
162
163
164
165
166
167
168


class Qwen2MoeAttention(nn.Module):
    """Rotary self-attention sharded across tensor-parallel ranks.

    Query heads are split evenly over ranks. Key/value heads are split when
    there are at least as many of them as ranks, and replicated otherwise
    (each rank keeps at least one KV head). The fused ``qkv_proj`` carries a
    bias; ``o_proj`` does not.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Compute per-rank head counts and build projections, RoPE and attention.

        ``head_dim`` is derived as ``hidden_size // num_heads``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads: must divide evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # KV heads: partitioned when there are enough of them, otherwise
        # replicated so that every rank still holds at least one.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            assert world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class Qwen2MoeDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention followed by a pre-norm MLP.

    The MLP is a ``Qwen2MoeSparseMoeBlock`` when the layer is not listed in
    ``config.mlp_only_layers``, ``config.num_experts > 0`` and
    ``(layer_idx + 1)`` is a multiple of ``config.decoder_sparse_step``;
    otherwise it is a dense ``Qwen2MoeMLP``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build attention, the (sparse or dense) MLP and the two RMSNorms.

        Args:
            config: HF model config.
            layer_idx: Zero-based index of this layer in the full stack; used
                to decide whether the MLP is sparse.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to attention and MLP layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        is_sparse = (layer_idx not in mlp_only_layers
                     and config.num_experts > 0
                     and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if is_sparse:
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config)
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then the MLP.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata for the current batch.
            residual: Running residual from the previous layer, or ``None``
                for the first layer (the input then becomes the residual).

        Returns:
            ``(hidden_states, residual)``. The two-argument ``RMSNorm`` calls
            return both normalized states and an updated residual
            (presumably a fused residual-add + norm), so the MLP output is
            returned without the residual added.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


316
@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Embedding, this pipeline stage's decoder layers, and the final norm.

    The first pipeline rank embeds ``input_ids``; later ranks resume from the
    ``IntermediateTensors`` handed over by the previous stage. Every rank
    except the last returns ``IntermediateTensors`` instead of normalized
    hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers invokes this with a keyword ``prefix`` whose last dotted
        # component is the layer index.
        def build_layer(prefix: str):
            layer_idx = int(prefix.split(".")[-1])
            return Qwen2MoeDecoderLayer(config=hf_config,
                                        layer_idx=layer_idx,
                                        cache_config=cache_config,
                                        quant_config=quant_config)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's layers.

        ``kv_caches`` is indexed relative to ``self.start_layer``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = self.embed_tokens(input_ids)
            residual = None

        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            local_cache = kv_caches[layer_idx - self.start_layer]
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    local_cache,
                                                    attn_metadata, residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


376
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen2MoE decoder stack plus LM head, logits processor and sampler.

    Pipeline-parallel aware: ``forward`` returns whatever ``Qwen2MoeModel``
    returns, i.e. ``IntermediateTensors`` on non-final pipeline ranks.
    """

    # NOTE(review): read by the model loader; presumably disables falling back
    # to PyTorch checkpoint files during weight loading -- confirm.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share the input embedding weight with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying ``Qwen2MoeModel``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF checkpoint tensors into this model.

        Each tensor is routed to exactly one of three loaders, tried in order:

        1. stacked shards (``q/k/v_proj`` -> ``qkv_proj``,
           ``gate/up_proj`` -> ``gate_up_proj``) outside ``mlp.experts``;
        2. routed-expert weights, mapped onto ``FusedMoE`` parameters;
        3. everything else, loaded by name (with FP8 kv-scale remapping).

        ``rotary_emb.inv_freq`` tensors are ignored. Once a stacked or expert
        mapping matches, that route alone decides whether the tensor is
        loaded or skipped. Previously a skipped stacked shard fell through,
        already renamed, to the other routes, where ``"v_proj"`` could
        re-match inside ``"qkv_proj"`` and lead to a KeyError.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self._load_stacked_weight(name, loaded_weight, params_dict,
                                         stacked_params_mapping):
                continue
            if self._load_expert_weight(name, loaded_weight, params_dict,
                                        expert_params_mapping):
                continue
            self._load_plain_weight(name, loaded_weight, params_dict)

    def _load_stacked_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, torch.nn.Parameter],
        stacked_params_mapping: List[Tuple[str, str, Any]],
    ) -> bool:
        """Load one q/k/v or gate/up shard into its fused parameter.

        Returns:
            True if a stacked mapping matched (the tensor was loaded or
            deliberately skipped); False if no stacked mapping applies.
        """
        # We have mlp.experts[0].gate_proj in the checkpoint. Experts are
        # handled by the expert mapping, so they must not be renamed here;
        # otherwise the name would become mlp.experts[0].gate_up_proj and
        # then mlp.experts[0].gate_gate_up_proj, which breaks loading.
        if "mlp.experts" in name:
            return False
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            fused_name = name.replace(weight_name, param_name)
            # Skip extra biases of GPTQ models, layers living on other
            # pipeline stages, and any parameter the model does not declare.
            # Return instead of trying the next mapping: the renamed
            # "qkv_proj" would otherwise be re-matched by "v_proj".
            if (is_pp_missing_parameter(fused_name, self)
                    or fused_name not in params_dict):
                return True
            param = params_dict[fused_name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            return True
        return False

    def _load_expert_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, torch.nn.Parameter],
        expert_params_mapping: List[Tuple[str, str, int, Any]],
    ) -> bool:
        """Load one routed-expert tensor into the fused ``FusedMoE`` params.

        Returns:
            True if an expert mapping matched (the tensor was loaded or
            deliberately skipped); False if no expert mapping applies.
        """
        for param_name, weight_name, expert_id, shard_id in (
                expert_params_mapping):
            if weight_name not in name:
                continue
            fused_name = name.replace(weight_name, param_name)
            # Skip layers on other devices.
            if is_pp_missing_parameter(fused_name, self):
                return True
            # Skip loading extra bias for GPTQ models.
            if ((fused_name.endswith(".bias") or fused_name.endswith("_bias"))
                    and fused_name not in params_dict):
                return True
            param = params_dict[fused_name]
            weight_loader = param.weight_loader
            weight_loader(param,
                          loaded_weight,
                          fused_name,
                          shard_id=shard_id,
                          expert_id=expert_id)
            return True
        return False

    def _load_plain_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: Dict[str, torch.nn.Parameter],
    ) -> None:
        """Load a tensor whose checkpoint name maps 1:1 onto a parameter.

        Raises:
            KeyError: If the (possibly kv-scale-remapped) name is not a
                parameter of this model.
        """
        # Skip loading extra bias for GPTQ models.
        if ((name.endswith(".bias") or name.endswith("_bias"))
                and name not in params_dict):
            return
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            return
        # Remapping the name of FP8 kv-scale.
        if name.endswith("kv_scale"):
            remapped_kv_scale_name = name.replace(".kv_scale",
                                                  ".attn.kv_scale")
            if remapped_kv_scale_name not in params_dict:
                print_warning_once(
                    "Found kv scale in the checkpoint "
                    f"(e.g. {name}), but not found the expected "
                    f"name in the model "
                    f"(e.g. {remapped_kv_scale_name}). "
                    "kv-scale is not loaded.")
                return
            name = remapped_kv_scale_name
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)