mixtral.py 18.8 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
24
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
52

Pierre Stock's avatar
Pierre Stock committed
53

Philipp Moritz's avatar
Philipp Moritz committed
54
55
56
57
58
59
60
61
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Routing logits come from an unquantized ``ReplicatedLinear`` gate; the
    experts live in a single ``FusedMoE`` layer built with
    ``reduce_results=True`` and ``renormalize=True``.
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        """
        Args:
            num_experts: Number of experts (width of the router output).
            top_k: Number of experts selected per token.
            hidden_size: Model hidden size (input/output width).
            intermediate_size: Expert FFN intermediate size.
            params_dtype: Parameter dtype forwarded to the gate and experts.
            quant_config: Quantization config forwarded to the experts only;
                the gate is always built with ``quant_config=None``.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its experts and return the mixed output.

        The input is flattened to ``(-1, hidden_size)`` for routing and the
        result is viewed back to the caller's original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
103
104
105
106


class MixtralAttention(nn.Module):
    """Self-attention with (possibly) fewer KV heads than query heads.

    Q/K/V come from one fused ``QKVParallelLinear``, rotary embeddings are
    applied to Q and K, and the result is projected back with a
    ``RowParallelLinear``. Heads are divided across the tensor-parallel group.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Model hidden size; ``head_dim`` is derived as
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads; must be divisible by the
                tensor-parallel world size.
            num_kv_heads: Total number of key/value heads.
            max_position: Maximum position given to the rotary embedding.
            rope_theta: Rotary base; converted with ``int()`` before use.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # At least as many KV heads as TP ranks, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Each rank keeps at least one KV head (replicated when scarce).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at ``positions``.

        The fused qkv output is split along its last dimension into this
        rank's q, k and v slices before rotary embedding and attention.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block: attention followed by a sparse MoE.

    Residuals are threaded between layers explicitly: ``forward`` takes and
    returns the running residual alongside the hidden states.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: HF Mixtral config supplying sizes, head counts, expert
                counts and the RMSNorm epsilon.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention and MoE modules.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MoE; return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the raw
        input becomes the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # RMSNorm called with a residual returns two tensors (presumably
            # a fused residual-add + norm — confirm against RMSNorm).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
246

247
248

class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, decoder layers, final norm.

    Only layers ``[start_layer, end_layer)`` (as returned by ``make_layers``)
    are run by this rank; presumably the slice owned by this pipeline stage.
    """

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: HF Mixtral config.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            lora_config: When set, enlarges the embedding table by
                ``lora_extra_vocab_size * (max_loras or 1)`` rows.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first pipeline rank embeds ``input_ids``; later ranks resume from
        ``intermediate_tensors``. Every rank except the last returns the
        ``hidden_states``/``residual`` pair as ``IntermediateTensors``; the
        last rank returns the final normalized hidden states.
        ``kv_caches`` is indexed relative to ``start_layer``.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


309
class MixtralForCausalLM(nn.Module, SupportsLoRA):
    """Mixtral causal LM: ``MixtralModel`` plus LM head, logits and sampler."""

    # NOTE(review): consumed by the model loader; semantics not visible here.
    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into the fused qkv_proj.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """
        Args:
            config: HF Mixtral config.
            cache_config: Forwarded to the decoder stack.
            quant_config: Forwarded to the decoder stack and the LM head.
            lora_config: When set, enlarges the LM head by
                ``lora_extra_vocab_size`` and uses the LoRA vocab padding.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config,
                                  prefix="model")

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                        lora_config.lora_vocab_padding_size)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            # Share the input embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns final hidden states on the last pipeline rank and
        ``IntermediateTensors`` on every other rank.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero ``hidden_states``/``residual`` buffers of shape
        ``(batch_size, hidden_size)`` for a pipeline stage."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        q/k/v projections are packed into ``qkv_proj``; per-expert
        ``w1``/``w2``/``w3`` tensors are loaded into the fused MoE
        parameters; everything else is loaded by name (after FP8 kv-scale
        name remapping). Weights owned by other pipeline ranks and extra
        GPTQ biases without a matching parameter are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            self._load_single_weight(name, loaded_weight, params_dict,
                                     stacked_params_mapping,
                                     expert_params_mapping)

    def _should_skip_weight(self, name: str, params_dict: dict) -> bool:
        """Return True if ``name`` must not be loaded on this rank."""
        # Skip loading extra bias for GPTQ models.
        if ((name.endswith(".bias") or name.endswith("_bias"))
                and name not in params_dict):
            return True
        # Skip layers on other devices.
        return is_pp_missing_parameter(name, self)

    def _load_single_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        stacked_params_mapping: List[Tuple[str, str, str]],
        expert_params_mapping: List[Tuple],
    ) -> None:
        """Load one checkpoint tensor, trying each mapping in turn.

        Returns as soon as the weight is loaded or deliberately skipped, so
        a renamed ``name`` is never re-matched against later mappings (e.g.
        ``"v_proj"`` is a substring of the renamed ``"qkv_proj"``).
        """
        # 1) Attention q/k/v shards packed into qkv_proj.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if self._should_skip_weight(name, params_dict):
                return
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            return

        # 2) Per-expert tensors loaded into the fused MoE parameters.
        for mapping in expert_params_mapping:
            param_name, weight_name, expert_id, shard_id = mapping
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if self._should_skip_weight(name, params_dict):
                return
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param,
                          loaded_weight,
                          name,
                          shard_id=shard_id,
                          expert_id=expert_id)
            return

        # 3) Everything else is loaded by (possibly remapped) name.
        if self._should_skip_weight(name, params_dict):
            return
        # Remapping the name of FP8 kv-scale.
        name = maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            return
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, loaded_weight)