mixtral.py 19.2 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
24
from typing import Iterable, List, Optional, Tuple
Pierre Stock's avatar
Pierre Stock committed
25
26
27

import torch
from torch import nn
28
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
29

30
from vllm.attention import Attention, AttentionMetadata
Terry's avatar
Terry committed
31
from vllm.config import LoRAConfig
32
33
34
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
Philipp Moritz's avatar
Philipp Moritz committed
35
from vllm.model_executor.layers.fused_moe import fused_moe
Pierre Stock's avatar
Pierre Stock committed
36
37
38
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
39
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
40
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
                                                         per_tensor_quantize)
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
45
46
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Pierre Stock's avatar
Pierre Stock committed
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
Philipp Moritz's avatar
Philipp Moritz committed
50
from vllm.model_executor.utils import set_weight_attrs
Pierre Stock's avatar
Pierre Stock committed
51
from vllm.sequence import SamplerOutput
52
from vllm.utils import print_warning_once
Pierre Stock's avatar
Pierre Stock committed
53
54


Philipp Moritz's avatar
Philipp Moritz committed
55
56
57
58
59
60
61
62
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Total number of experts held by this layer.
        top_k: Value forwarded to the fused MoE kernel as its ``topk``
            argument.
        hidden_size: Size of the hidden dimension of the incoming tokens.
        intermediate_size: Unsharded intermediate size of one expert. It is
            divided by the tensor-parallel size to obtain the per-rank shard.
        params_dtype: dtype of the gate and expert weights. Defaults to
            ``torch.get_default_dtype()``.
        tp_size: Tensor-parallel size. Defaults to
            ``get_tensor_model_parallel_world_size()``.
        linear_method: When this is an ``Fp8LinearMethod`` the expert weights
            are quantized to FP8 after loading.

    Raises:
        ValueError: If ``intermediate_size`` is not divisible by the
            tensor-parallel size.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # A non-divisible size would make the floor division below silently
        # drop the trailing rows/columns of every expert when sharding.
        if intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({intermediate_size}) must be divisible "
                f"by the tensor parallel size ({self.tp_size}).")
        # Per-rank slice of the expert intermediate dimension.
        self.intermediate_size = intermediate_size // self.tp_size
        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # The router is replicated on every rank and is created without a
        # linear method (linear_method=None).
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        # Stacked w1/w3 shards of every expert:
        # (num_experts, 2 * intermediate_size_per_rank, hidden_size).
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        # w2 shards of every expert:
        # (num_experts, hidden_size, intermediate_size_per_rank).
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        # Scaling factors for FP8 weights (one float32 per expert); left as
        # None when FP8 is not in use.
        if self.use_fp8:
            self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    device="cuda",
                                                    dtype=torch.float32),
                                         requires_grad=False)
            self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     device="cuda",
                                                     dtype=torch.float32),
                                          requires_grad=False)
        else:
            self.ws_scale = None
            self.w2s_scale = None

        # Checkpoint loading looks up `weight_loader` on the parameter.
        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert weight into ``param``.

        Args:
            param: Destination parameter (``ws`` or ``w2s``).
            loaded_weight: Full, unsharded checkpoint tensor of one expert.
            weight_name: Checkpoint weight name; its suffix (``w1.weight``,
                ``w3.weight`` or ``w2.weight``) selects the destination
                region. Any other suffix is ignored.
            expert_id: Index of the expert along dim 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 fills the first half of the stacked rows.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 fills the second half of the stacked rows.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its second dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def process_weights_after_loading(self):
        """Quantize the loaded expert weights to FP8; no-op unless FP8 is on.

        Each expert is quantized on its own with ``per_tensor_quantize``; the
        returned scale is stored in ``ws_scale`` / ``w2s_scale`` and the
        weights are replaced by ``float8_e4m3fn`` parameters.
        """
        if not self.use_fp8:
            return

        ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
        w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
        for expert in range(self.num_total_experts):
            ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
                self.ws.data[expert, :, :])
            w2s[expert, :, :], self.w2s_scale[expert] = per_tensor_quantize(
                self.w2s.data[expert, :, :])
        self.ws = nn.Parameter(ws, requires_grad=False)
        self.w2s = nn.Parameter(w2s, requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts.

        Args:
            hidden_states: 2-D tensor unpacked as (num_tokens, hidden_size).

        Returns:
            The fused-kernel output, all-reduced across ranks when
            ``tp_size > 1``, viewed as (num_tokens, hidden_size).
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True,
                                        use_fp8=self.use_fp8,
                                        w1_scale=self.ws_scale,
                                        w2_scale=self.w2s_scale)

        # Each rank only holds a slice of every expert, so the outputs are
        # reduced across ranks.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)
Pierre Stock's avatar
Pierre Stock committed
171
172
173
174
175
176
177
178
179
180


class MixtralAttention(nn.Module):
    """Tensor-parallel self-attention block used by the Mixtral decoder layer.

    Query heads are divided evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many of them as ranks, and
    replicated otherwise.

    Args:
        hidden_size: Size of the model's hidden dimension.
        num_heads: Total number of query heads (before sharding).
        num_kv_heads: Total number of key/value heads (before sharding).
        max_position: Maximum position handed to the rotary embedding.
        rope_theta: Rotary embedding base; converted to ``int`` before use.
        linear_method: Quantization method for the projections. An
            ``Fp8LinearMethod`` is dropped (with a one-time warning) so the
            attention projections stay unquantized.
        sliding_window: Sliding-window size forwarded to ``Attention``.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        if isinstance(linear_method, Fp8LinearMethod):
            print_warning_once(
                "For Mixtral FP8 quantization, we currently do not quantize "
                "the attention layers until their FP8 performance is improved."
            )
            linear_method = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply QKV projection, rotary embedding, attention and output
        projection.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Input activations of this layer.
            kv_cache: KV cache tensor passed to ``Attention``.
            attn_metadata: Attention metadata passed to ``Attention``.

        Returns:
            The output of ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: self-attention followed by a sparse MoE
    block, each preceded by an RMSNorm.

    Args:
        config: Model configuration providing the sizes, head counts,
            sliding window, expert counts and RMSNorm epsilon.
        linear_method: Quantization method forwarded to the attention and
            MoE sub-modules.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer on one batch of tokens.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Input activations.
            kv_cache: KV cache tensor of this layer.
            attn_metadata: Attention metadata forwarded to the attention block.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer.

        Returns:
            A ``(hidden_states, residual)`` tuple: the MoE output and the
            residual returned by the post-attention layer norm.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns the updated pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
314

315
316

class MixtralModel(nn.Module):
    """Token embedding, the stack of Mixtral decoder layers and a final
    RMSNorm.

    Args:
        config: Model configuration.
        linear_method: Quantization method forwarded to every decoder layer.
        lora_config: Optional LoRA configuration; when given, the embedding
            vocabulary is enlarged by
            ``lora_extra_vocab_size * (max_loras or 1)`` entries.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # Extra vocabulary rows reserved for LoRA-added tokens.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the tokens and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions forwarded to every layer.
            kv_caches: KV cache tensors, indexed by layer.
            attn_metadata: Attention metadata forwarded to every layer.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[layer_idx],
                                            attn_metadata, residual)
        # The final norm also receives the last layer's residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral decoder with an LM head, logits processor and sampler.

    Args:
        config: Model configuration.
        linear_method: Quantization method forwarded to the decoder.
        lora_config: Optional LoRA configuration; enlarges the LM head
            vocabulary and selects its padding size.
    """

    fall_back_to_pt_during_load = False

    # Fused module name -> names of the checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's hidden states for the given tokens.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head weight on
        ``hidden_states``."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each weight is dispatched, in order, to the first rule that applies:

        1. ``q_proj`` / ``k_proj`` / ``v_proj`` weights go into the fused
           ``qkv_proj`` parameter through its shard-aware ``weight_loader``.
        2. ``experts.<id>.w{1,2,3}.weight`` weights go into the stacked
           ``ws`` (w1, w3) or ``w2s`` (w2) parameter of the MoE block.
        3. Everything else is loaded under its own name with the parameter's
           ``weight_loader``, or ``default_weight_loader`` if it has none.

        ``rotary_emb.inv_freq`` entries and ``.bias`` entries that have no
        matching parameter are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Raises:
            KeyError: If a weight that is not skipped has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # `break` drops the weight right here. Continuing the loop
                # would test the already-rewritten name again, and "v_proj"
                # is a substring of "qkv_proj".
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not an attention projection: try the expert weights next.
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)