mixtral.py 19.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Pierre Stock's avatar
Pierre Stock committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
25
from typing import Iterable, Optional, Set, Tuple, Union
Pierre Stock's avatar
Pierre Stock committed
26
27
28

import torch
from torch import nn
29
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
30

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.fused_moe import FusedMoE
Pierre Stock's avatar
Pierre Stock committed
36
from vllm.model_executor.layers.layernorm import RMSNorm
37
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
38
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
39
                                               RowParallelLinear)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.quantization import QuantizationConfig
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
43
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
Pierre Stock's avatar
Pierre Stock committed
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
47
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
Pierre Stock's avatar
Pierre Stock committed
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import IntermediateTensors
Pierre Stock's avatar
Pierre Stock committed
50

51
52
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
53
54
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

Pierre Stock's avatar
Pierre Stock committed
56

Philipp Moritz's avatar
Philipp Moritz committed
57
58
59
60
61
62
63
64
class MixtralMoE(nn.Module):
    """Mixtral sparse mixture-of-experts block, tensor-parallel over experts.

    A replicated router produces per-expert logits for every token; a single
    ``FusedMoE`` layer then evaluates the selected experts. Each expert's
    weights are sharded across all ranks, and the expert outputs are reduced
    across ranks inside ``FusedMoE`` (``reduce_results=True``).

    Args:
        num_experts: number of experts (router output width).
        top_k: experts selected per token.
        hidden_size: model hidden width (input and output width).
        intermediate_size: expert FFN inner width.
        params_dtype: dtype for router and expert parameters.
        quant_config: quantization config applied to the experts only.
        tp_size: forwarded unchanged to ``FusedMoE``.
        dp_size: forwarded unchanged to ``FusedMoE``.
        prefix: parameter-name prefix for the submodules.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 dp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # The router is deliberately never quantized (quant_config=None):
        # it always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            quant_config=quant_config,
            renormalize=True,
            reduce_results=True,
            tp_size=tp_size,
            dp_size=dp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and return output in the input's shape."""
        # The input may be 1D or 2D: flatten it to (num_tokens, hidden_size)
        # for the router and fused kernel, then restore the caller's shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        logits, _ = self.gate(tokens)  # (num_tokens, n_experts)
        return self.experts(tokens, logits).view(input_shape)


class MixtralAttention(nn.Module):
    """Mixtral self-attention with grouped KV heads and rotary embeddings.

    Query heads are partitioned across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, otherwise they are
    replicated so that every rank holds at least one KV head.
    """

    def __init__(
        self,
        config: MixtralConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: HF config; only its optional ``head_dim`` is read here.
            hidden_size: model hidden width.
            num_heads: total query heads (across all TP ranks).
            num_kv_heads: total key/value heads (across all TP ranks).
            max_position: maximum position passed to the rotary embedding.
            rope_theta: rotary base (converted to ``int`` for ``get_rope``).
            cache_config: KV-cache configuration for the attention backend.
            quant_config: quantization config for projections and attention.
            prefix: parameter-name prefix for the submodules.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MixtralConfig has an optional head_dim argument. The attribute can
        # exist but hold None (e.g. `"head_dim": null` in config.json), in
        # which case getattr's default is NOT used, so fall back explicitly.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        # Fused Q/K/V projection; its output is split per rank in forward().
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Local widths: this rank's query heads and (possibly replicated)
        # KV heads, each head_dim wide.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer: normed attention, then normed MoE.

    The residual stream is carried alongside the hidden states: each RMSNorm
    is called with ``(hidden_states, residual)`` and hands back an updated
    pair, so ``forward`` returns both tensors to the next layer.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return the updated ``(hidden_states, residual)``.

        Args:
            positions: token positions, passed through to the attention.
            hidden_states: input activations for this layer.
            residual: running residual stream, or None to start a new one
                from ``hidden_states``.

        Returns:
            Tuple of the MoE output and the residual to feed the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual

251

252
@support_torch_compile
class MixtralModel(nn.Module):
    """Mixtral decoder stack: embeddings, this rank's layers, final norm.

    With pipeline parallelism only ``layers[start_layer:end_layer]`` run on
    this rank. Every rank except the last returns ``IntermediateTensors``
    holding ``hidden_states`` and ``residual`` instead of normalized output.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        # With LoRA enabled, reserve extra embedding rows for every adapter
        # slot (max_loras, or a single slot if unset).
        extra_vocab = 0
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        self.org_vocab_size = config.vocab_size
        self.vocab_size = config.vocab_size + extra_vocab

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # make_layers invokes the factory with a `prefix=` keyword, so the
        # parameter name must stay `prefix`.
        def build_layer(prefix: str) -> MixtralDecoderLayer:
            return MixtralDecoderLayer(config,
                                       cache_config,
                                       quant_config=quant_config,
                                       prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first rank starts from ``inputs_embeds`` (if given) or from the
        embeddings of ``input_ids``; later ranks resume from
        ``intermediate_tensors``. The last rank returns normalized hidden
        states, every other rank returns ``IntermediateTensors``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            hidden_states = inputs_embeds
            residual = None

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })


317
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Mixtral causal language model for inference.

    Wraps ``MixtralModel`` with an LM head, logits processor and sampler,
    and maps checkpoint tensor names onto the fused parameters of this model
    (packed ``qkv_proj``, per-expert ``w1``/``w2``/``w3``) in
    ``load_weights``.
    """

    fall_back_to_pt_during_load = False

    # Fused module name -> checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = MixtralModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (lora_config.lora_vocab_padding_size
                        if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share one weight tensor between input embedding and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; see ``MixtralModel.forward`` for PP behavior."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits (may be None)."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded. Tensors that are
            skipped (rotary caches, extra GPTQ biases, layers owned by other
            pipeline ranks, unmappable kv-scales) are not included.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Non-scalar scale tensors: only the first element is used.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Set when a q/k/v tensor must be dropped. A plain `continue`
            # inside the stacked loop would keep matching the remaining
            # mappings against an already-rewritten name ("v_proj" occurs
            # inside "qkv_proj"), or against None after a failed kv-scale
            # remap, which raised TypeError on the next `in` test.
            skip_weight = False
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    skip_weight = True
                    break
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    skip_weight = True
                    break
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale; None means the
                    # scale has no destination parameter in this model.
                    remapped_name = maybe_remap_kv_scale_name(
                        name, params_dict)
                    if remapped_name is None:
                        skip_weight = True
                        break
                    name = remapped_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices. The rewritten name no
                    # longer matches later expert mappings, so the loop
                    # falls through to the `else` below, whose PP check
                    # drops the tensor.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            if skip_weight:
                continue
            loaded_params.add(name)
        return loaded_params