mixtral.py 13.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
Lianmin Zheng's avatar
Lianmin Zheng committed
18
"""Inference-only Mixtral model."""
19
from typing import Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
20
21
22
23

import torch
from torch import nn
from transformers import MixtralConfig
24
25
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
Lianmin Zheng's avatar
Lianmin Zheng committed
26
from vllm.model_executor.layers.rotary_embedding import get_rope
27
28
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

29
from sglang.srt.layers.layernorm import RMSNorm
30
31
32
33
34
from sglang.srt.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
35
from sglang.srt.layers.logits_processor import LogitsProcessor
36
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
37
from sglang.srt.layers.radix_attention import RadixAttention
38
from sglang.srt.layers.torchao_utils import apply_torchao_config_
39
40
41
42
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
43
from sglang.srt.managers.schedule_batch import global_server_args_dict
44
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Liangsheng Yin's avatar
Liangsheng Yin committed
45

Lianmin Zheng's avatar
Lianmin Zheng committed
46

Lianmin Zheng's avatar
Lianmin Zheng committed
47
48
49
class MixtralMoE(nn.Module):
    """Mixture-of-experts feed-forward block for Mixtral.

    A replicated linear router (``gate``) scores every token against all
    experts, and a ``FusedMoE`` layer (``experts``) evaluates the selected
    experts. Each expert's weights are sharded across the tensor-parallel
    ranks; the fused layer is built with ``reduce_results=True`` so the
    per-rank outputs are reduced at the end.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        """Build the router and the fused expert layer.

        Args:
            num_experts: Number of experts the router chooses between.
            top_k: Number of experts evaluated per token.
            hidden_size: Width of the incoming hidden states.
            intermediate_size: Forwarded to ``FusedMoE`` unchanged.
            params_dtype: Parameter dtype forwarded to both submodules.
            quant_config: Quantization config, applied to the experts only.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            prefix: Dotted module-name prefix for the submodules.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # The router is deliberately created with quant_config=None: the gate
        # always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,
            reduce_results=True,
            renormalize=True,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a tensor shaped like the input."""
        # The input may be 1D or 2D; remember its shape so it can be restored.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        expert_output = self.experts(tokens, router_logits)
        return expert_output.view(input_shape)
Lianmin Zheng's avatar
Lianmin Zheng committed
101
102
103
104
105
106
107
108
109
110
111


class MixtralAttention(nn.Module):
    """Self-attention block for Mixtral.

    Chains a fused QKV projection, rotary position embedding,
    ``RadixAttention`` and an output projection. Head counts are divided by
    the tensor-parallel world size to obtain the per-rank counts.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, rotary embedding and attention backend.

        Args:
            hidden_size: Model hidden width.
            num_heads: Total number of query heads across all ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            layer_id: Layer index forwarded to ``RadixAttention``.
            max_position: Maximum position forwarded to ``get_rope``.
            rope_theta: Rotary base; truncated to ``int`` before use.
            quant_config: Quantization config for the two projections.
            prefix: Dotted module-name prefix for the projections.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads

        # Query heads must divide evenly over the tensor-parallel ranks.
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        if self.total_num_kv_heads >= world_size:
            # At least one KV head per rank: partition the KV heads across
            # the tensor-parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        else:
            # Fewer KV heads than ranks: replicate the KV heads across the
            # tensor-parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` and return the projected output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection back into query, key and value slices.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, forward_batch)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer: self-attention followed by the MoE block.

    Both sub-blocks are preceded by an ``RMSNorm``. The residual stream is
    threaded through the norms and returned to the caller alongside the
    hidden states.
    """

    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MoE and normalization submodules.

        Args:
            config: Model configuration providing sizes, head counts, expert
                counts and the RMSNorm epsilon.
            layer_id: Index of this layer, forwarded to the attention block.
            quant_config: Quantization config forwarded to attention and MoE.
            prefix: Dotted module-name prefix for the submodules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Token positions passed to the attention block.
            hidden_states: Input activations for this layer.
            forward_batch: Batch metadata passed to the attention block.
            residual: Residual stream from the previous layer, or ``None``
                when there is no previous residual yet.

        Returns:
            The MoE output and the updated residual, to be passed to the
            next layer (or to the final norm).
        """
        # Self Attention
        if residual is None:
            # No residual yet: the raw input seeds the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns the updated residual too.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    """Mixtral backbone: token embedding, a stack of decoder layers, final norm."""

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding table, the decoder layers and the final RMSNorm.

        Args:
            config: Model configuration (vocabulary size, hidden size, number
                of layers, RMSNorm epsilon, padding token id).
            quant_config: Quantization config forwarded to every decoder layer.
            prefix: Dotted module-name prefix of this module.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(
                    config,
                    layer_idx,
                    quant_config=quant_config,
                    # Include the layer index so each layer's prefix matches
                    # its name inside the ModuleList ("<prefix>.layers.<i>")
                    # rather than every layer sharing "<prefix>.layers".
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the final normalized hidden states.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Token positions passed to every layer.
            forward_batch: Batch metadata passed to every layer.
            input_embeds: Optional precomputed embeddings used in place of
                the embedding lookup.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # The first layer receives residual=None and seeds the residual stream.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral backbone plus LM head and logits processor, with weight loading."""

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        """Build the backbone, the LM head and the logits processor.

        Args:
            config: Model configuration.
            quant_config: Quantization config forwarded to the backbone.
            cache_config: Accepted but not used by this class.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and hand its hidden states to the logits processor.

        Returns whatever ``self.logits_processor`` returns for the given
        ``input_ids``, hidden states, LM-head weight and ``forward_batch``.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is tried, in order, as:

        1. a q/k/v projection shard that belongs in the fused ``qkv_proj``;
        2. an expert weight matched via ``FusedMoE.make_expert_params_mapping``;
        3. any other parameter, loaded through its own ``weight_loader`` or
           ``default_weight_loader``.

        ``rotary_emb.inv_freq`` entries are ignored, as are ``.bias`` and
        ``.kv_scale`` entries that have no matching parameter. After all
        weights are consumed, ``apply_torchao_config_`` is invoked.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # The for/else chain below runs each `else` only when the loop
            # above it finished without a `break`, i.e. nothing matched.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        apply_torchao_config_(self, params_dict, {"proj.weight"})

Lianmin Zheng's avatar
Lianmin Zheng committed
385

Cody Yu's avatar
Cody Yu committed
386
# Module-level alias for this file's top-level model class.
# NOTE(review): presumably looked up by name by the model loader — confirm.
EntryClass = MixtralForCausalLM