mixtral.py 17.2 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
24
from typing import Iterable, List, Optional, Tuple
Pierre Stock's avatar
Pierre Stock committed
25
26
27

import torch
from torch import nn
28
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
29

30
from vllm.attention import Attention, AttentionMetadata
Terry's avatar
Terry committed
31
from vllm.config import LoRAConfig
32
33
34
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
Philipp Moritz's avatar
Philipp Moritz committed
35
from vllm.model_executor.layers.fused_moe import fused_moe
Pierre Stock's avatar
Pierre Stock committed
36
37
38
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
39
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
40
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
43
44
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Pierre Stock's avatar
Pierre Stock committed
47
from vllm.model_executor.sampling_metadata import SamplingMetadata
Philipp Moritz's avatar
Philipp Moritz committed
48
from vllm.model_executor.utils import set_weight_attrs
Pierre Stock's avatar
Pierre Stock committed
49
50
51
from vllm.sequence import SamplerOutput


Philipp Moritz's avatar
Philipp Moritz committed
52
53
54
55
56
57
58
59
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Total number of experts held by this layer.
        top_k: Forwarded to the fused MoE kernel as its top-k argument.
        hidden_size: Size of the last dimension of the input activations.
        intermediate_size: Unsharded expert intermediate size; each rank
            keeps ``intermediate_size // tp_size`` of it.
        params_dtype: Dtype of the gate and expert weights. Defaults to
            ``torch.get_default_dtype()``.
        tp_size: Tensor-parallel world size. When falsy, the value reported
            by ``get_tensor_model_parallel_world_size()`` is used.

    Raises:
        ValueError: If ``intermediate_size`` is not a multiple of the
            tensor-parallel size.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        # An uneven split would make the floor division below silently drop
        # the trailing rows/columns of every expert when the weights are
        # sliced in `weight_loader`, so reject it up front.
        if intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({intermediate_size}) must be divisible "
                f"by the tensor-parallel size ({self.tp_size})")
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Per-rank slice of the expert intermediate dimension.
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Router: every rank holds the full gate weight.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        # `ws` stacks this rank's w1 shard (first `intermediate_size` rows)
        # and w3 shard (next `intermediate_size` rows) for every expert;
        # see `weight_loader` for the exact layout.
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        # `w2s` holds this rank's column shard of w2 for every expert.
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert weight into ``param``.

        Args:
            param: Destination parameter (``ws`` for w1/w3, ``w2s`` for w2).
            loaded_weight: Full, unsharded checkpoint tensor for one expert.
            weight_name: Checkpoint name; its suffix (``w1.weight``,
                ``w2.weight`` or ``w3.weight``) selects the destination slot.
            expert_id: Index of the expert along dimension 0 of ``param``.

        Raises:
            ValueError: If ``weight_name`` does not end with one of the
                three recognized suffixes.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 is sharded along its rows and fills the first half of `ws`.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 is sharded along its rows and fills the second half of `ws`.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its columns.
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        else:
            # Previously an unknown name was ignored, leaving the
            # `torch.empty` storage uninitialized without any signal.
            raise ValueError(
                f"Unexpected expert weight name: {weight_name!r}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        Args:
            hidden_states: 2-D tensor unpacked as ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_size)``.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # NOTE(review): `inplace=True` presumably lets the kernel reuse the
        # `hidden_states` buffer for its output -- confirm against fused_moe.
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        # Each rank only computed with its shard of the expert weights, so
        # the partial results are summed across ranks.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)
Pierre Stock's avatar
Pierre Stock committed
139
140
141
142
143
144
145
146
147
148


class MixtralAttention(nn.Module):
    """Tensor-parallel self-attention used by each Mixtral decoder layer.

    The layer applies a fused QKV projection, rotary position embedding on
    the query/key tensors, attention against the KV cache, and finally the
    output projection.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # KV heads are partitioned across ranks when there are at least as
        # many heads as ranks, and replicated otherwise. Either way the
        # larger of the two counts must be a multiple of the smaller one.
        self.total_num_kv_heads = num_kv_heads
        larger = max(self.total_num_kv_heads, world_size)
        smaller = min(self.total_num_kv_heads, world_size)
        assert larger % smaller == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank projection widths.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=False,
                                          linear_method=linear_method)
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        linear_method=linear_method)
        self.rotary_emb = get_rope(self.head_dim,
                                   rotary_dim=self.head_dim,
                                   max_position=max_position,
                                   base=int(self.rope_theta),
                                   is_neox_style=True)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=self.sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last axis.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """A single Mixtral decoder layer.

    The layer normalizes its input, runs self-attention, normalizes again
    and then runs the sparse mixture-of-experts block. The residual stream
    is carried alongside the hidden states instead of being added here, so
    ``forward`` returns both tensors.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MoE block to ``hidden_states``.

        Args:
            positions: Token positions forwarded to the attention module.
            hidden_states: Input activations for this layer.
            kv_cache: KV cache tensor for this layer.
            attn_metadata: Attention metadata forwarded to the attention
                module.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer.

        Returns:
            A ``(hidden_states, residual)`` pair: the MoE output and the
            residual stream to hand to the next normalization.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # When given a residual, the norm returns the normalized
            # activations together with the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
274

275
276

class MixtralModel(nn.Module):
    """Mixtral backbone: token embedding, decoder stack and final norm."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id

        # LoRA adapters may extend the vocabulary; reserve room for the
        # extra tokens of every adapter slot.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        decoder_layers = [
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` is indexed by layer position, one entry per layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it itself.
        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    kv_caches[layer_idx],
                                                    attn_metadata, residual)
        normed_states, _ = self.norm(hidden_states, residual)
        return normed_states


class MixtralForCausalLM(nn.Module):
    """Mixtral backbone plus LM head, logits processor and sampler."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)

        unpadded_vocab = config.vocab_size
        if lora_config:
            unpadded_vocab += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = unpadded_vocab

        # We need bigger padding if using lora for kernel compatibility
        if lora_config:
            head_padding = lora_config.lora_vocab_padding_size
        else:
            head_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=head_padding,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states (logits are computed apart)."""
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint entry is routed, in order of precedence, to:
        1. the fused ``qkv_proj`` parameter (for q/k/v projections),
        2. the stacked expert parameters ``ws`` / ``w2s``,
        3. the identically named parameter, using its own ``weight_loader``
           if it has one and ``default_weight_loader`` otherwise.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (stacked param name, checkpoint weight name, expert id)
        expert_params_mapping = []
        for expert_idx in range(self.config.num_local_experts):
            for proj in ["w1", "w2", "w3"]:
                stacked_name = "ws" if proj in ["w1", "w3"] else "w2s"
                expert_params_mapping.append(
                    (stacked_name, f"experts.{expert_idx}.{proj}.weight",
                     expert_idx))

        params_dict = dict(self.named_parameters())
        for key, tensor in weights:
            if "rotary_emb.inv_freq" in key:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in key:
                    key = key.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models.
                    extra_bias = (key.endswith(".bias")
                                  and key not in params_dict)
                    if not extra_bias:
                        target = params_dict[key]
                        target.weight_loader(target, tensor, shard_id)
                        break
            else:
                # Not a q/k/v shard: try the expert weights next.
                for stacked_name, expert_key, expert_id in (
                        expert_params_mapping):
                    if expert_key in key:
                        key = key.replace(expert_key, stacked_name)
                        target = params_dict[key]
                        target.weight_loader(target,
                                             tensor,
                                             expert_key,
                                             expert_id=expert_id)
                        break
                else:
                    # Plain parameter. Extra biases of GPTQ models that
                    # have no matching parameter are skipped.
                    if key in params_dict or not key.endswith(".bias"):
                        target = params_dict[key]
                        loader = getattr(target, "weight_loader",
                                         default_weight_loader)
                        loader(target, tensor)