mixtral.py 13.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
# Adapted from
Lianmin Zheng's avatar
Lianmin Zheng committed
17
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
Lianmin Zheng's avatar
Lianmin Zheng committed
18
"""Inference-only Mixtral model."""
19
from typing import Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
20
21
22
23

import torch
from torch import nn
from transformers import MixtralConfig
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from vllm.config import CacheConfig
25
26
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
Lianmin Zheng's avatar
Lianmin Zheng committed
27
28
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
29
    DEFAULT_VOCAB_PADDING_SIZE,
Lianmin Zheng's avatar
Lianmin Zheng committed
30
31
32
    ParallelLMHead,
    VocabParallelEmbedding,
)
33
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

35
from sglang.srt.layers.layernorm import RMSNorm
36
37
38
39
40
from sglang.srt.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
41
from sglang.srt.layers.logits_processor import LogitsProcessor
42
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
43
from sglang.srt.layers.radix_attention import RadixAttention
44
45
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
46
from sglang.srt.model_executor.forward_batch_info import InputMetadata
Liangsheng Yin's avatar
Liangsheng Yin committed
47

Lianmin Zheng's avatar
Lianmin Zheng committed
48

Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
51
class MixtralMoE(nn.Module):
    """Sparse mixture-of-experts feed-forward block for Mixtral.

    A replicated, unquantized router (``gate``) produces one logit per expert
    for every token, and a ``FusedMoE`` layer evaluates the experts with a
    fused kernel. Each expert's weights are sharded across the tensor-parallel
    ranks; ``reduce_results=True`` asks ``FusedMoE`` to reduce the per-rank
    partial outputs.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        """
        Args:
            num_experts: total number of experts (router output width).
            top_k: number of experts selected per token, forwarded to FusedMoE.
            hidden_size: model hidden dimension (router input size).
            intermediate_size: expert MLP inner dimension.
            params_dtype: dtype passed to both the router and the experts.
            quant_config: quantization config applied to the experts only; the
                router is always constructed with ``quant_config=None``.
            tp_size: tensor-parallel size forwarded to FusedMoE.
            prefix: module-name prefix; ``.gate`` / ``.experts`` are appended.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The router is never quantized: it always runs at half / full precision.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        expert_kwargs = dict(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )
        self.experts = FusedMoE(**expert_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its experts and return the result in the
        same shape as ``hidden_states`` (which may be 1D or 2D)."""
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        # Router logits: (num_tokens, n_experts). The gate's second return
        # value is discarded.
        logits, _ = self.gate(tokens)
        mixed = self.experts(tokens, logits)
        return mixed.view(input_shape)
Lianmin Zheng's avatar
Lianmin Zheng committed
103
104
105
106
107
108
109
110
111
112
113


class MixtralAttention(nn.Module):
    """Mixtral self-attention with neox-style rotary position embeddings.

    The number of KV heads may be smaller than the number of query heads.
    Query heads are split evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks; otherwise they are
    replicated so that each rank holds a single KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        world = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads

        # Query heads must split evenly across the TP ranks.
        assert num_heads % world == 0
        if num_kv_heads < world:
            # Fewer KV heads than ranks: each KV head is replicated, so the
            # rank count must be a multiple of the KV head count.
            assert world % num_kv_heads == 0
        else:
            # At least one KV head per rank: partition them evenly.
            assert num_kv_heads % world == 0

        self.num_heads = num_heads // world
        self.num_kv_heads = max(1, num_kv_heads // world)

        self.head_dim = hidden_size // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        # Fused Q/K/V projection; its output is split per-rank in forward().
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Note: rope_theta is truncated to an int for the rotary base.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, input_metadata)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer block.

    Structure: RMSNorm -> self-attention -> RMSNorm -> sparse MoE feed-forward.
    A separate residual stream is threaded through the RMSNorm calls and
    returned alongside the hidden states, so the caller can pass it to the
    next layer.
    """

    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: HF Mixtral config supplying sizes, head counts and eps.
            layer_id: index of this layer, forwarded to the attention module.
            quant_config: quantization config forwarded to attention and MoE.
            prefix: module-name prefix; ``.self_attn`` / ``.block_sparse_moe``
                are appended for the submodules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block.

        Returns:
            ``(hidden_states, residual)``: the MoE output and the updated
            residual stream. The annotation previously claimed a single
            tensor; the method has always returned this pair.
        """
        # Self Attention
        if residual is None:
            # No residual stream yet: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normed, updated residual).
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    """Mixtral decoder stack.

    Consists of a vocab-parallel token embedding, ``config.num_hidden_layers``
    ``MixtralDecoderLayer`` blocks and a final RMSNorm.
    """

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: HF Mixtral config.
            quant_config: quantization config forwarded to every decoder layer.
            prefix: module-name prefix for this stack (e.g. ``"model"``).
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # Each layer gets its own prefix "<prefix>.layers.<i>", matching its
        # module path (self.layers[i]). Previously every layer received the
        # same "<prefix>.layers", so the names handed to the linear/MoE
        # submodules were neither unique nor correct. NOTE(review): these
        # prefixes are presumably used by quant configs to match layers.
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (unless ``input_embeds`` is given), run all layers, and apply
        the final norm.

        Args:
            input_ids: token ids; only used when ``input_embeds`` is None.
            positions: positions passed to every layer's rotary embedding.
            input_metadata: batch metadata forwarded to attention.
            input_embeds: optional precomputed embeddings that bypass
                ``embed_tokens``.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # The residual stream starts empty; the first layer seeds it.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral causal LM.

    Combines a ``MixtralModel`` backbone (module name ``model``), a parallel
    LM head and a logits processor, plus checkpoint weight loading.
    """

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """
        Args:
            config: HF Mixtral config.
            quant_config: quantization config forwarded to the backbone.
            cache_config: accepted for interface compatibility; unused here.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        # NOTE(review): presumably consumed by apply_torchao_config_ at the
        # end of load_weights().
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and hand its hidden states, together with the LM
        head weight, to the logits processor."""
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        For each checkpoint name, three routes are tried in order:

        1. ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written as the
           "q" / "k" / "v" shards of the fused ``qkv_proj``.
        2. Per-expert ``w1`` (gate), ``w2`` (down) and ``w3`` (up) tensors
           are routed to the ``FusedMoE`` parameters using
           ``FusedMoE.make_expert_params_mapping``.
        3. Anything else is loaded with the parameter's ``weight_loader``,
           falling back to ``default_weight_loader``.

        ``rotary_emb.inv_freq`` entries and extra biases (GPTQ) that have no
        matching parameter are skipped. Any other name with no matching
        parameter raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. (`name` is already
                # rewritten here; the else-branch below will skip it again.)
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        apply_torchao_config_(self, params_dict, {"proj.weight"})

Lianmin Zheng's avatar
Lianmin Zheng committed
384

Cody Yu's avatar
Cody Yu committed
385
# Top-level model class exported by this module (NOTE(review): presumably
# discovered by SGLang's model registry via this name — confirm).
EntryClass = MixtralForCausalLM