mixtral.py 18.7 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
24
from typing import Iterable, List, Optional, Tuple, Union
Pierre Stock's avatar
Pierre Stock committed
25
26
27

import torch
from torch import nn
28
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
29

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, LoRAConfig
33
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
34
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
35
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
37
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
38
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
42
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
Pierre Stock's avatar
Pierre Stock committed
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
45
46
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
47
from vllm.model_executor.sampling_metadata import SamplingMetadata
48
from vllm.sequence import IntermediateTensors
Pierre Stock's avatar
Pierre Stock committed
49

50
51
52
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
53

Pierre Stock's avatar
Pierre Stock committed
54

Philipp Moritz's avatar
Philipp Moritz committed
55
56
57
58
59
60
61
62
class MixtralMoE(nn.Module):
    """Tensor-parallel mixture-of-experts feed-forward block for Mixtral.

    The weights of every expert are split across all tensor-parallel ranks.
    A replicated linear router scores each token against every expert, a
    fused MoE kernel (``FusedMoE``) evaluates the selected experts, and the
    per-rank partial results are reduced across ranks
    (``reduce_results=True``).
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Build the router and the fused expert layer.

        Args:
            num_experts: Total number of experts.
            top_k: Number of experts each token is routed to.
            hidden_size: Model hidden dimension.
            intermediate_size: Feed-forward dimension of each expert.
            params_dtype: Parameter dtype for both router and experts.
            quant_config: Quantization config; applied to the experts only.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            prefix: Module-name prefix used to name the child modules.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The router is intentionally built without quantization
        # (quant_config=None): it always runs at half / full precision.
        router_name = f"{prefix}.gate"
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=router_name)

        experts_name = f"{prefix}.experts"
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,
            # Renormalize the selected top-k routing weights and all-reduce
            # the expert output across tensor-parallel ranks.
            renormalize=True,
            reduce_results=True,
            prefix=experts_name,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine their outputs.

        ``hidden_states`` may be 1D or 2D. It is flattened to
        ``(num_tokens, hidden_size)`` for routing, and the result is reshaped
        back to the shape of the input.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)

        # Router scores, one row per token: (num_tokens, num_experts).
        router_logits, _ = self.gate(tokens)

        expert_output = self.experts(tokens, router_logits)
        return expert_output.view(input_shape)
Pierre Stock's avatar
Pierre Stock committed
104
105
106
107


class MixtralAttention(nn.Module):
    """Self-attention with grouped KV heads and rotary position embeddings.

    Query heads are divided evenly across tensor-parallel ranks. KV heads are
    divided across ranks when there are at least as many KV heads as ranks.
    Otherwise they are replicated, so that each rank holds one KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the QKV/output projections, rotary embedding and kernel.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total (un-sharded) number of query heads.
            num_kv_heads: Total (un-sharded) number of key/value heads.
            max_position: Maximum position supported by the rotary embedding.
            rope_theta: Rotary embedding base (truncated to ``int``).
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization config for projections and attention.
            prefix: Module-name prefix used to name the projections.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads

        # Query heads must split evenly across ranks.
        assert self.total_num_heads % world_size == 0
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: each KV head is replicated, so the
            # rank count must be a multiple of the KV head count.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned.
            assert self.total_num_kv_heads % world_size == 0

        # Per-rank head counts (never fewer than one KV head per rank).
        self.num_heads = self.total_num_heads // world_size
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project back."""
        fused_qkv, _ = self.qkv_proj(hidden_states)

        # Per-rank widths of the query, key and value slices of the fused
        # projection output.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)

        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block: self-attention followed by a sparse MoE.

    The block is pre-norm and threads an explicit residual stream. It takes
    the previous block's output together with the running residual, and
    returns both, for the next block (or the final norm) to consume.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MoE and the two RMSNorms from ``config``.

        Args:
            config: Mixtral model configuration.
            cache_config: KV-cache configuration forwarded to attention.
            quant_config: Quantization config forwarded to attention and MoE.
            prefix: Module-name prefix used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0; older configs have no rope_theta,
        # so fall back to 10000.
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MoE feed-forward on ``hidden_states``.

        Args:
            positions: Token positions, passed to the rotary embedding.
            hidden_states: Output of the previous block (or the embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata for the current batch.
            residual: Running residual stream. ``None`` on the first call, in
                which case ``hidden_states`` itself becomes the residual.

        Returns:
            A ``(hidden_states, residual)`` pair: the MoE output and the
            updated residual. The final residual add is left to the next
            norm call that receives both.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # When given a residual, the norm returns an updated
            # (normalized, residual) pair — presumably a fused
            # residual-add + RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
247

248

249
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[self.start_layer, self.end_layer)`` assigned by ``make_layers``.
    The first rank embeds ``input_ids``. Every rank except the last hands
    ``hidden_states``/``residual`` to the next stage as
    ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, this rank's decoder layers and final norm.

        Args:
            config: Mixtral model configuration.
            cache_config: KV-cache configuration forwarded to each layer.
            quant_config: Quantization config forwarded to each layer.
            lora_config: Optional LoRA config; it enlarges the vocabulary.
            prefix: Module-name prefix used to name the layers.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id

        # LoRA may add extra tokens: reserve lora_extra_vocab_size rows per
        # adapter slot (max_loras slots, or one if that is unset).
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.org_vocab_size = config.vocab_size
        self.vocab_size = config.vocab_size + extra_vocab

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # Layer factory for make_layers; the argument must stay named
        # ``prefix`` so it can be passed by keyword.
        def build_layer(prefix):
            return MixtralDecoderLayer(config,
                                       cache_config,
                                       quant_config=quant_config,
                                       prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final normalized hidden states on the last pipeline
        rank. On every other rank, returns ``IntermediateTensors`` holding
        ``hidden_states`` and ``residual``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later stages resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = self.embed_tokens(input_ids)
            residual = None

        for layer_idx in range(self.start_layer, self.end_layer):
            # kv_caches is indexed relative to this rank's first layer.
            local_kv_cache = kv_caches[layer_idx - self.start_layer]
            hidden_states, residual = self.layers[layer_idx](
                positions, hidden_states, local_kv_cache, attn_metadata,
                residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


314
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral causal language model: ``MixtralModel`` plus an LM head.

    Supports LoRA adapters and pipeline parallelism. ``load_weights`` maps
    checkpoints that store separate q/k/v projections and per-expert
    ``w1``/``w2``/``w3`` tensors onto the fused vLLM parameters.
    """

    fall_back_to_pt_during_load = False

    # Fused parameter -> checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
        "gate"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack, LM head, logits processor and sampler.

        Args:
            config: Mixtral model configuration.
            cache_config: KV-cache configuration forwarded to the model.
            quant_config: Quantization config for the model and LM head.
            lora_config: Optional LoRA config; it enlarges the vocabulary
                and the LM-head padding.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config,
                                  prefix="model")

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                        lora_config.lora_vocab_padding_size)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share one weight tensor between input embedding and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns hidden states on the last pipeline rank, or
        ``IntermediateTensors`` on earlier ranks.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _should_skip_param(self, name: str, params_dict: dict) -> bool:
        """Return True if checkpoint tensor ``name`` must not be loaded here.

        A tensor is skipped in two cases:

        - It is a bias with no matching parameter, such as the extra biases
          in GPTQ checkpoints.
        - Its parameter belongs to a layer owned by another pipeline-parallel
          rank.
        """
        if name.endswith((".bias", "_bias")) and name not in params_dict:
            return True
        return is_pp_missing_parameter(name, self)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Each tensor is routed to the first matching rule:

        1. ``q_proj``/``k_proj``/``v_proj`` shards go into the fused
           ``qkv_proj`` parameter.
        2. Per-expert ``w1``/``w2``/``w3`` tensors, including their fp8
           scales, go into the fused expert parameters.
        3. Every other tensor is loaded by name, after remapping FP8
           kv-scale names.

        ``rotary_emb.inv_freq`` tensors are ignored, and any tensor rejected
        by ``_should_skip_param`` is skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # 1. Shards of the fused QKV projection.
            # Once a mapping matches, always ``break``, whether the tensor is
            # loaded or skipped. The rewritten name must never be matched
            # again: "v_proj" is a substring of "qkv_proj", so continuing
            # the scan would mangle the name into "qkqkv_proj".
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if not self._should_skip_param(name, params_dict):
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # 2. Per-expert MoE weights (same match-then-break rule).
                for (param_name, weight_name, expert_id,
                     shard_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if not self._should_skip_param(name, params_dict):
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    break
                else:
                    # 3. Everything else, loaded by (possibly remapped) name.
                    if self._should_skip_param(name, params_dict):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)