mixtral.py 17.8 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from typing import List, Optional, Tuple
Pierre Stock's avatar
Pierre Stock committed
25
26
27

import torch
from torch import nn
28
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
29

Terry's avatar
Terry committed
30
from vllm.config import LoRAConfig
Pierre Stock's avatar
Pierre Stock committed
31
from vllm.model_executor.input_metadata import InputMetadata
32
from vllm.model_executor.layers.attention import Attention
Philipp Moritz's avatar
Philipp Moritz committed
33
from vllm.model_executor.layers.fused_moe import fused_moe
Pierre Stock's avatar
Pierre Stock committed
34
35
36
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
37
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
38
39
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Pierre Stock's avatar
Pierre Stock committed
41
42
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
Terry's avatar
Terry committed
43
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
Pierre Stock's avatar
Pierre Stock committed
44
45
46
47
48
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
Philipp Moritz's avatar
Philipp Moritz committed
49
from vllm.model_executor.utils import set_weight_attrs
Pierre Stock's avatar
Pierre Stock committed
50
51
52
53
54
55
56
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

# Per-layer attention cache: a (key_cache, value_cache) pair of tensors, as
# unpacked by MixtralAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]


Philipp Moritz's avatar
Philipp Moritz committed
57
58
59
60
61
62
63
64
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Expert weights are stored stacked along dim 0 (one slice per expert):

    * ``ws``  -- ``(num_experts, 2 * intermediate_size // tp_size,
      hidden_size)``; this rank's shard of ``w1`` occupies the first half of
      dim 1 and its shard of ``w3`` the second half.
    * ``w2s`` -- ``(num_experts, hidden_size, intermediate_size // tp_size)``.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        """
        Args:
            num_experts: Total number of experts. Every rank holds a shard of
                every expert.
            top_k: Number of experts each token is routed to.
            hidden_size: Model hidden dimension.
            intermediate_size: Full (unsharded) expert intermediate width;
                each rank keeps ``intermediate_size // tp_size`` of it.
            params_dtype: Dtype of the gate and expert weights; defaults to
                ``torch.get_default_dtype()``.
            tp_size: Tensor-parallel world size; falls back to the value
                reported by the parallel state when not given (or falsy).
        """
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Per-rank slice of each expert's intermediate dimension.
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Router: every rank keeps a full (replicated) copy.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        # NOTE: expert weights are allocated directly on the CUDA device.
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert checkpoint tensor into the
        stacked parameter ``param``.

        Args:
            param: ``self.ws`` (for ``w1`` / ``w3``) or ``self.w2s`` (for
                ``w2``).
            loaded_weight: The full, unsharded checkpoint tensor.
            weight_name: Checkpoint name. Its ``w1.weight`` / ``w3.weight`` /
                ``w2.weight`` suffix selects the destination; names with any
                other suffix are ignored.
            expert_id: Index of the expert along dim 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 shard (rows of the checkpoint) -> first half of dim 1.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 shard (rows of the checkpoint) -> second half of dim 1.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its second (column) dimension instead.
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and combine their outputs.

        Accepts any tensor whose trailing dimension is ``hidden_size`` (e.g.
        ``(batch, seq, hidden)`` or ``(num_tokens, hidden)``) and returns a
        tensor of the same shape.

        NOTE(review): ``fused_moe`` is called with ``inplace=True``, which
        presumably lets it write its result into the storage of
        ``hidden_states``; callers should not reuse the input afterwards.
        """
        orig_shape = hidden_states.shape
        # Flatten all leading dims into a single token dimension.
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        if self.tp_size > 1:
            # Each rank only holds a slice of every expert's intermediate
            # dimension, so the per-rank outputs must be summed.
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)
Pierre Stock's avatar
Pierre Stock committed
145
146
147
148
149
150
151
152
153
154


class MixtralAttention(nn.Module):
    """Self-attention with rotary position embeddings and (optionally fewer)
    key/value heads, split across tensor-parallel ranks.

    Query heads are always partitioned across ranks. Key/value heads are
    partitioned when there are at least as many of them as ranks; otherwise
    they are replicated and each rank keeps a single KV head.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads: split evenly over the ranks.
        self.total_num_heads = num_heads
        assert num_heads % world_size == 0
        self.num_heads = num_heads // world_size

        # Key/value heads: split when there are enough of them, otherwise
        # replicate (the rank count must then be a multiple of them).
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world_size:
            assert world_size % num_kv_heads == 0
        else:
            assert num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, num_kv_heads // world_size)

        self.head_dim = hidden_size // num_heads
        # Widths of this rank's q and k (or v) slices of the fused projection.
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.head_dim * self.total_num_heads,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, rotate q and k by ``positions``, attend using
        the layer's key/value cache, and project back to ``hidden_size``."""
        projected, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = projected.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata)
        output, _ = self.o_proj(context)
        return output


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block: pre-norm self-attention followed by a
    pre-norm sparse mixture-of-experts feed-forward.

    The residual stream is threaded through the layer explicitly. The
    two-argument form of ``RMSNorm`` returns a ``(normalized, residual)``
    pair (presumably fusing the residual add into the norm), so the residual
    is handed back to the caller instead of being added here.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        # NOTE: linear_method is not passed to the MoE block.
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MoE feed-forward to ``hidden_states``.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Attention metadata for the current batch.
            residual: Residual stream from the previous layer, or ``None``
                when ``hidden_states`` are the raw embeddings.

        Returns:
            ``(hidden_states, residual)`` for the next layer or final norm.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the incoming activations start the stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
281

282
283

class MixtralModel(nn.Module):
    """Mixtral decoder stack: token embedding, ``num_hidden_layers`` decoder
    layers and a final RMSNorm (the LM head lives in the wrapping model)."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        # With LoRA, the embedding table gets ``lora_extra_vocab_size`` extra
        # rows for each of ``max_loras`` adapters (at least one).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer with its own entry of
        ``kv_caches`` and return the final normalized hidden states."""
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual; each layer hands back the
        # residual stream for the next one.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral causal LM: ``MixtralModel`` plus LM head, logits processor and
    sampler, with LoRA support and Hugging Face checkpoint loading."""

    # Checkpoint projections that are fused into a single module here.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)

        self.unpadded_vocab_size = config.vocab_size
        padding_size = DEFAULT_VOCAB_PADDING_SIZE
        if lora_config:
            # Room for the LoRA extra vocabulary, and bigger padding for
            # LoRA kernel compatibility.
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            padding_size = lora_config.lora_vocab_padding_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the final hidden states of the decoder stack (no logits)."""
        return self.model(input_ids, positions, kv_caches, input_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM head weights."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load a Hugging Face checkpoint into this model's parameters.

        Each checkpoint tensor is handled by the first matching case:

        1. ``q_proj`` / ``k_proj`` / ``v_proj`` -> shard of ``qkv_proj``.
        2. ``experts.<i>.w1|w2|w3.weight`` -> expert ``i`` of the stacked
           ``ws`` (w1, w3) or ``w2s`` (w2) MoE parameters.
        3. Anything else -> the same-named parameter, via its own
           ``weight_loader`` or ``default_weight_loader``.

        ``rotary_emb.inv_freq`` entries are skipped, as are ``.bias`` tensors
        the model has no parameter for (extra GPTQ biases).
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (stacked param name, checkpoint weight name, expert id)
        expert_params_mapping = []
        for expert_id in range(self.config.num_local_experts):
            for proj in ["w1", "w2", "w3"]:
                target = "w2s" if proj == "w2" else "ws"
                expert_params_mapping.append(
                    (target, f"experts.{expert_id}.{proj}.weight", expert_id))

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue

            # Case 1: a q/k/v shard of the fused qkv_proj.
            stacked_hit = False
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                stacked_hit = True
                break
            if stacked_hit:
                continue

            # Case 2: one expert's w1/w2/w3 going into the stacked tensors.
            expert_hit = False
            for param_name, expert_weight, expert_id in expert_params_mapping:
                if expert_weight not in name:
                    continue
                name = name.replace(expert_weight, param_name)
                param = params_dict[name]
                param.weight_loader(param,
                                    loaded_weight,
                                    expert_weight,
                                    expert_id=expert_id)
                expert_hit = True
                break
            if expert_hit:
                continue

            # Case 3: every other parameter, loaded under its own name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)