mixtral.py 14.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Lianmin Zheng's avatar
Lianmin Zheng committed
15
# Adapted from
Lianmin Zheng's avatar
Lianmin Zheng committed
16
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
Lianmin Zheng's avatar
Lianmin Zheng committed
17
"""Inference-only Mixtral model."""
18

19
from typing import Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
20
21
22
23

import torch
from torch import nn
from transformers import MixtralConfig
24
25
26
from vllm.model_executor.layers.rotary_embedding import get_rope

from sglang.srt.distributed import (
xiaobochen's avatar
xiaobochen committed
27
28
29
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
30
from sglang.srt.layers.layernorm import RMSNorm
31
32
33
34
35
from sglang.srt.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
36
from sglang.srt.layers.logits_processor import LogitsProcessor
Ke Bao's avatar
Ke Bao committed
37
38
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
39
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
40
from sglang.srt.layers.radix_attention import RadixAttention
41
42
43
44
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
xiaobochen's avatar
xiaobochen committed
45
from sglang.srt.managers.schedule_batch import global_server_args_dict
46
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
47
from sglang.srt.model_loader.weight_utils import default_weight_loader
Liangsheng Yin's avatar
Liangsheng Yin committed
48

Lianmin Zheng's avatar
Lianmin Zheng committed
49

Lianmin Zheng's avatar
Lianmin Zheng committed
50
51
52
class MixtralMoE(nn.Module):
    """Tensor-parallel sparse MoE feed-forward block for Mixtral.

    A replicated router (``gate``) produces per-token expert logits, and the
    expert computation runs in ``EPMoE`` when
    ``global_server_args_dict["enable_ep_moe"]`` is set, or in the fused
    Triton ``FusedMoE`` kernel otherwise. When more than one tensor-parallel
    rank is in use, the per-rank partial outputs are combined with
    ``tensor_model_parallel_all_reduce``.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        """Build the router and expert implementation.

        Args:
            num_experts: Total number of experts.
            top_k: Number of experts each token is routed to.
            hidden_size: Model hidden dimension (input/output width).
            intermediate_size: Per-expert FFN intermediate dimension.
            params_dtype: Parameter dtype for the gate and experts.
            quant_config: Quantization config for the experts; the gate is
                never quantized.
            tp_size: Tensor-parallel size used by the experts. ``None`` means
                "use the tensor-parallel world size".
            prefix: Module name prefix used for sub-module prefixes.
        """
        super().__init__()
        # Decide on the all-reduce from the same TP size the experts are
        # given. If an explicit tp_size were ignored here, forward() could
        # all-reduce outputs that were never partitioned across ranks.
        # NOTE(review): assumes the MoE impl also resolves tp_size=None to the
        # TP world size — confirm against EPMoE / FusedMoE.
        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to experts and return the combined expert output.

        The input is flattened to ``(-1, hidden_size)`` for routing, and the
        result is reshaped back to the input's original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(orig_shape)
Lianmin Zheng's avatar
Lianmin Zheng committed
106
107
108
109
110
111
112
113
114
115
116


class MixtralAttention(nn.Module):
    """Self-attention block of a Mixtral decoder layer.

    Q, K and V come from a single fused ``QKVParallelLinear`` projection.
    Rotary position embeddings are applied to Q and K, attention is computed
    by ``RadixAttention``, and ``RowParallelLinear`` projects the result back
    to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        world = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads are split evenly over the tensor-parallel group.
        self.total_num_heads = num_heads
        assert num_heads % world == 0
        self.num_heads = num_heads // world

        # KV heads are replicated when there are fewer of them than ranks,
        # and partitioned otherwise; either way the counts must divide.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world:
            assert world % num_kv_heads == 0
        else:
            assert num_kv_heads % world == 0
        self.num_kv_heads = max(1, num_kv_heads // world)

        head_dim = hidden_size // num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position,
            base=int(rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, forward_batch)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer.

    The layer applies RMSNorm and self-attention, then RMSNorm and the sparse
    MoE block. The residual stream is passed explicitly between layers:
    ``forward`` returns ``(hidden_states, residual)``, and the two-argument
    ``RMSNorm(hidden_states, residual)`` calls return an updated residual.
    NOTE(review): presumably that call adds the residual before normalizing;
    confirm against ``RMSNorm``.
    """

    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MoE and the two norms from ``config``.

        Args:
            config: Mixtral model configuration.
            layer_id: Index of this layer (forwarded to attention).
            quant_config: Optional quantization config for sub-layers.
            prefix: Module name prefix for sub-module prefixes.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention followed by the MoE block.

        Args:
            positions: Token positions, used by the attention's rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            forward_batch: Batch metadata forwarded to attention.
            residual: Residual stream from the previous layer. ``None`` for
                the first layer, in which case the incoming ``hidden_states``
                becomes the residual.

        Returns:
            A ``(hidden_states, residual)`` tuple: the MoE output and the
            residual stream for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    """Mixtral decoder stack.

    The stack consists of a token embedding, ``config.num_hidden_layers``
    decoder layers, and a final RMSNorm.
    """

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, decoder layers and final norm.

        Args:
            config: Mixtral model configuration.
            quant_config: Optional quantization config forwarded to each layer.
            prefix: Module name prefix (e.g. ``"model"``).
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # Give every layer its own indexed prefix (e.g. "model.layers.3").
        # With a shared "<prefix>.layers" prefix, all layers' sub-modules got
        # identical names. Any prefix-based decision (presumably per-layer
        # quantization choices made by quant_config — confirm) then could
        # not tell layers apart.
        self.layers = nn.ModuleList(
            [
                MixtralDecoderLayer(
                    config,
                    i,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed (unless embeddings are given), run all layers, then normalize.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Token positions passed to every layer.
            forward_batch: Batch metadata passed to every layer.
            input_embeds: Optional precomputed input embeddings that replace
                the token embedding lookup.

        Returns:
            The hidden states after the final norm.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral causal language model.

    Combines the ``MixtralModel`` backbone with a ``ParallelLMHead`` and a
    ``LogitsProcessor``. ``load_weights`` maps checkpoint tensors (separate
    ``q_proj``/``k_proj``/``v_proj`` and per-expert ``w1``/``w2``/``w3``) onto
    the fused parameters used by this implementation.
    """

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the backbone (prefix ``"model"``), LM head and logits processor."""
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return ``logits_processor``'s output for the batch."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is routed in order of precedence:

        1. q/k/v projections go into the matching shard of the fused ``qkv_proj``.
        2. Expert weights (``w1``/``w2``/``w3`` and their quantization scales)
           go into the MoE implementation's expert parameters.
        3. Everything else is loaded directly by name.

        Rotary ``inv_freq`` buffers are skipped. Bias tensors with no matching
        parameter (extra GPTQ biases) are skipped, as are ``.kv_scale``
        tensors with no matching parameter. Any other name missing from the
        model raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        # Must use the same MoE implementation that MixtralMoE instantiated,
        # so the parameter names line up.
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # (The renamed name then falls through the for/else below
                # and is skipped by the same check there.)
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a q/k/v shard: try the expert mappings next.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip loading kv_scale from ckpts towards new design.
                    if name.endswith(".kv_scale") and name not in params_dict:
                        continue

                    # Plain parameter: use its custom loader if it has one.
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)


Cody Yu's avatar
Cody Yu committed
396
# Module-level alias exposing the top-level model class of this file.
# NOTE(review): presumably looked up by SGLang's model registry when loading
# Mixtral checkpoints — confirm against the registry code.
EntryClass = MixtralForCausalLM