mixtral.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

from torch import nn
from transformers import MixtralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
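        # w1/w3 are column-parallel: this rank's slice of rows goes into the
        # stacked `ws` buffer (w1 in the first half, w3 in the second half of
        # dim 1). w2 is row-parallel: this rank's slice of columns goes into
        # `w2s`.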
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)
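        # Top-k routing: softmax over all experts, keep the top_k gate values
        # per token, then renormalize them so they sum to 1.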

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

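        # fused_moe applies each selected expert's gated FFN (w1/w3 then w2)
        # using this rank's weight shards in a single fused kernel, scaling
        # each expert's output by its routing weight.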
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        routing_weights,
                                        selected_experts,
                                        inplace=True)

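        # Each rank holds only a partial sum over its shard of the expert
        # weights; all-reduce to recover the full MoE output.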
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(batch_size, sequence_length,
                                        hidden_size)


class MixtralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
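        # RMSNorm here supports a fused residual add: when called with a
        # residual tensor it returns (normalized_output, updated_residual).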
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
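        # e.g. ("ws", "experts.0.w1.weight", 0): expert 0's w1 tensor from the
        # HF checkpoint is loaded into this rank's shard of the stacked `ws`.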

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue

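            # `for ... else`: the else branch below runs only if no stacked
            # q/k/v shard name matched (the loop completed without `break`).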
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)