# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Mixtral model."""

import logging
from typing import Iterable, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from sglang.srt.distributed import (
    get_moe_expert_parallel_world_size,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers

logger = logging.getLogger(__name__)


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks, a fused MoE kernel is
    used for the forward pass, and the outputs are finally all-reduced across
    ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=add_prefix("gate", prefix),
        )

        self.topk = TopK(
            top_k=top_k,
            renormalize=True,
        )

        MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
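        # EPMoE assigns whole experts to different expert-parallel ranks, while the
        # fused Triton MoE shards every expert's weights across tensor-parallel ranks.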
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(hidden_states, topk_output)
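        # Under tensor parallelism each rank holds only a partial expert output,
        # so all-reduce across TP ranks to recover the full hidden states.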
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(orig_shape)


class MixtralAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so we
            # partition the KV heads across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
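        # RadixAttention writes k/v into the KV cache and runs attention using the
        # batch metadata carried in forward_batch.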
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("block_sparse_moe", prefix),
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
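        # RMSNorm fuses the residual add: when a residual is provided, it returns
        # both the normalized hidden states and the updated residual.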
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: MixtralDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="layers",
            return_tuple=True,
        )

        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
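        # The first PP rank computes embeddings locally; later ranks receive
        # hidden_states and residual from the previous rank via pp_proxy_tensors.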
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class MixtralForCausalLM(nn.Module):
    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = MixtralModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
        )
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
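        # Run the transformer body; non-final PP ranks return proxy tensors
        # instead of final hidden states.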
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
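        # Checkpoints store separate q_proj/k_proj/v_proj weights, which are loaded
        # into the fused qkv_proj parameter; shard_id selects the target slice.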
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
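            # Under pipeline parallelism, skip weights for layers owned by other PP ranks.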
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
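                # Not a fused q/k/v projection; try the per-expert w1/w2/w3 mappings.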
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip loading kv_scale from checkpoints; it is handled by the newer quantization design.
                    if name.endswith(".kv_scale") and name not in params_dict:
                        continue
                    if name is None:
                        continue

                    if name in params_dict.keys():
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                    else:
                        logger.warning(f"Parameter {name} not found in params_dict")


EntryClass = MixtralForCausalLM
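# SGLang's model loader discovers this implementation through the module-level
# EntryClass attribute defined above.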