mixtral.py 18.6 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
23
from typing import Iterable, List, Optional, Tuple, Union
Pierre Stock's avatar
Pierre Stock committed
24
25
26

import torch
from torch import nn
27
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
28

29
from vllm.attention import Attention, AttentionMetadata
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
33
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
34
from vllm.model_executor.layers.layernorm import RMSNorm
35
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
36
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
37
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
41
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Pierre Stock's avatar
Pierre Stock committed
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
44
45
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.sequence import IntermediateTensors
Pierre Stock's avatar
Pierre Stock committed
48

49
50
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
51
52
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
53

Pierre Stock's avatar
Pierre Stock committed
54

Philipp Moritz's avatar
Philipp Moritz committed
55
56
57
58
59
60
61
62
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """Build the router (gate) and the fused experts.

        Args:
            num_experts: Number of experts; also the router's output width.
            top_k: Number of experts each token is routed to.
            hidden_size: Model hidden dimension (router input width).
            intermediate_size: Hidden width inside each expert MLP.
            params_dtype: Optional dtype forwarded to the gate and experts.
            quant_config: Quantization config, forwarded to the experts only.
            tp_size: Optional tensor-parallel size forwarded to FusedMoE.
            prefix: Module-name prefix used to name the sub-modules.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now: the router is
        # built with quant_config=None regardless of the model's config.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        # reduce_results=True asks FusedMoE to reduce the per-rank partial
        # outputs; renormalize=True presumably renormalizes the top-k routing
        # weights (confirm against FusedMoE).
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their experts and return the combined output.

        The input is flattened to (-1, hidden_size) for routing, and the
        result is viewed back to the input's original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
104
105
106
107


class MixtralAttention(nn.Module):
    """Mixtral self-attention, partitioned across tensor-parallel ranks.

    Query heads are split evenly across ranks. Key/value heads are split
    when there are at least as many of them as ranks, and replicated
    otherwise. Rotary position embeddings are applied to q and k before the
    attention backend runs.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the fused QKV projection, output projection, RoPE and backend.

        Args:
            hidden_size: Model hidden dimension; head_dim is derived as
                hidden_size // num_heads.
            num_heads: Total number of query heads across all ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            max_position: Maximum position passed to the rotary embedding.
            rope_theta: Rotary base; truncated to int before use.
            cache_config: Forwarded to the attention backend.
            quant_config: Forwarded to qkv_proj, o_proj and the backend.
            prefix: Module-name prefix used to name the projections.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head (the replicated case above).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Attention scale 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # NOTE(review): int() truncates a fractional rope_theta; confirm
        # whether get_rope accepts a float base before changing this.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project back out.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations whose last dimension is
                hidden_size (the qkv_proj input width).
            kv_cache: This layer's KV cache, passed to the attention backend.
            attn_metadata: Backend metadata for the current batch.

        Returns:
            The output of o_proj.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's q, k and v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral block.

    The block applies pre-norm self-attention, then a pre-norm sparse
    mixture-of-experts feed-forward.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MoE and the two RMSNorms from a Mixtral config.

        Args:
            config: Hugging Face Mixtral config supplying the layer sizes.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention and MoE modules.
            prefix: Module-name prefix for this layer's sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0; falls back to 10000 when the
        # config has no rope_theta attribute.
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MoE for one layer.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention backend metadata.
            residual: Running residual stream, or None for the first layer.

        Returns:
            A ``(hidden_states, residual)`` pair: the MoE output and the
            residual stream. The caller passes both to the next
            two-argument RMSNorm call, which presumably performs the
            residual add.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # The two-argument RMSNorm returns (normed, new_residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
247

248

249
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack.

    Holds the token embedding, this pipeline stage's decoder layers and the
    final RMSNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, decoder layers and the final norm.

        Args:
            vllm_config: Engine config; the HF model config, cache,
                quantization and LoRA configs are read from it.
            prefix: Module-name prefix for the decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.padding_idx = config.pad_token_id
        # With LoRA enabled, the embedding table gets extra rows:
        # lora_extra_vocab_size times max_loras (or 1 if unset).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # make_layers returns the [start_layer, end_layer) index range used
        # in forward, presumably the layers owned by this pipeline stage.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's decoder layers.

        The first stage embeds ``input_ids``. Later stages resume from
        ``intermediate_tensors``, which must then be provided.

        Returns:
            IntermediateTensors with "hidden_states" and "residual" on
            non-final stages; the final-normed hidden states on the last
            stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to start_layer (presumably one
            # entry per layer local to this stage).
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


313
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral causal language model for inference.

    Wraps MixtralModel with an LM head, a logits processor and a sampler.
    load_weights maps checkpoint tensor names onto the fused parameter
    layout (packed qkv_proj, FusedMoE experts).
    """

    # NOTE(review): presumably tells the model loader not to fall back to
    # PyTorch .pt checkpoint files; confirm against the loader.
    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into the fused qkv_proj.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3",
        "gate"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Engine config; model, quantization and LoRA
                configs are read from it.
            prefix: Module-name prefix; the backbone is placed under "model".
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = MixtralModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share one Parameter between the input embedding and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        The output is hidden states on the last pipeline stage and
        IntermediateTensors elsewhere.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Choose next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each (name, tensor) pair takes one of three paths:

        - q/k/v shards are loaded into the fused qkv_proj.
        - Per-expert w1/w2/w3 tensors are loaded into the FusedMoE
          parameters.
        - Anything else is copied with the parameter's weight_loader
          (default_weight_loader if it has none).

        The following are skipped:

        - rotary inv_freq buffers;
        - a tied lm_head.weight;
        - extra GPTQ biases with no matching parameter;
        - parameters belonging to other pipeline stages;
        - kv-scale names that cannot be remapped.

        Args:
            weights: Iterable of (checkpoint_name, tensor) pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings lm_head.weight is the same Parameter as
            # model.embed_tokens.weight. named_parameters() drops duplicate
            # parameters, so "lm_head.weight" has no params_dict entry and
            # would raise KeyError below. The shared tensor is filled when
            # the embed_tokens weight is loaded.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            # NOTE(review): the `continue`s inside the two mapping loops
            # resume the *inner* loop with `name` already rewritten. Skipped
            # weights therefore end up in the final `else`, which repeats
            # the bias / pipeline-parallel checks; keep those in sync.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)