llama4.py 19.2 KB
Newer Older
Chang Su's avatar
Chang Su committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/models/llama4.py
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import Llama4TextConfig

from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
30
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
fzyzcjy's avatar
fzyzcjy committed
31
32
33
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
34
    get_local_attention_dp_size,
35
    is_dp_attention_enabled,
fzyzcjy's avatar
fzyzcjy committed
36
)
Chang Su's avatar
Chang Su committed
37
38
39
40
41
42
43
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
44
from sglang.srt.layers.moe.topk import TopK
Chang Su's avatar
Chang Su committed
45
46
47
48
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
49
50
51
52
53
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
Chang Su's avatar
Chang Su committed
54
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
55
56
57
58
59
60
61
62
63
from sglang.srt.utils import (
    add_prefix,
    fast_topk,
    get_compiler_backend,
    is_cuda,
    make_layers,
)

# Resolved once at import time. Llama4MoE._forward_core consults it to decide
# whether the stream-overlapped shared/routed expert path may be used.
_is_cuda = is_cuda()

logger = logging.getLogger(__name__)


class Llama4MoE(nn.Module):
    """Llama 4 mixture-of-experts feed-forward block.

    Every token goes through a dense shared expert and through ``top_k``
    routed experts picked by a linear router; the two results are summed.
    Both the routed experts and the shared expert are built with
    ``reduce_results=False``, so ``forward`` performs a single tensor-parallel
    all-reduce on the sum (unless the caller asks for reduce-scatter instead).
    """

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the top-``topk`` experts per token and sigmoid-gate them.

        Args:
            hidden_states: Only its dtype is used; routing weights are cast
                to it.
            gating_output: Router logits; top-k is taken over the last dim.
            topk: Number of experts selected per token.
            renormalize: Not used by this implementation (the ``TopK``
                wrapper below is created with ``renormalize=False``).

        Returns:
            ``(weights, expert_ids)``: the sigmoid of the selected logits in
            ``hidden_states.dtype`` and the selected expert indices as int32.
        """
        top_logits, expert_ids = fast_topk(gating_output, topk, dim=-1)
        # Sigmoid is evaluated in fp32, then cast back to the activation dtype.
        weights = torch.sigmoid(top_logits.float()).to(hidden_states.dtype)
        # NOTE(review): flatten-then-reshape keeps values and shape unchanged;
        # presumably it exists to hand back a freshly laid-out tensor.
        weights = weights.view(-1).reshape(weights.shape)
        return weights, expert_ids.to(torch.int32)

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok
        self.device_module = torch.get_device_module()

        hidden = config.hidden_size
        expert_intermediate = config.intermediate_size

        # Router: unquantized and replicated on every rank.
        self.router = ReplicatedLinear(
            hidden,
            config.num_local_experts,
            bias=False,
            quant_config=None,
            prefix=add_prefix("router", prefix),
        )

        self.topk = TopK(
            top_k=self.top_k,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
        )

        # Routed experts. The TP reduction is deferred to ``forward`` so it
        # happens once on (routed + shared).
        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            hidden_size=hidden,
            intermediate_size=expert_intermediate,
            layer_id=layer_id,
            reduce_results=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,
            prefix=add_prefix("experts", prefix),
        )

        # Dense expert applied to every token; it uses the same intermediate
        # size as the routed experts.
        self.shared_expert = LlamaMLP(
            hidden_size=hidden,
            intermediate_size=expert_intermediate,
            hidden_act="silu",
            quant_config=quant_config,
            prefix=add_prefix("shared_expert", prefix),
            reduce_results=False,  # We need to do scatter before reduce
        )

    def forward(
        self,
        hidden_states,
        forward_batch: ForwardBatch,
        use_reduce_scatter: bool = False,
    ):
        """Run shared and routed experts and combine their outputs.

        Args:
            hidden_states: Token activations for the router and the experts.
            forward_batch: Provides ``forward_mode`` to the core computation.
            use_reduce_scatter: When True the all-reduce is skipped here so
                that the caller can reduce-scatter instead.

        Returns:
            ``routed + shared``; all-reduced across TP ranks when
            ``tp_size > 1`` and ``use_reduce_scatter`` is False.
        """
        shared_out, routed_out = self._forward_core(
            hidden_states, forward_batch.forward_mode
        )
        combined = routed_out + shared_out

        if self.tp_size <= 1 or use_reduce_scatter:
            return combined
        return tensor_model_parallel_all_reduce(combined)

    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
        """Pick the execution path and return ``(shared_out, routed_out)``.

        Batches of fewer than 4 tokens on CUDA take the stream-overlapped
        path. ``forward_mode`` is accepted but not consulted here.
        """
        overlap = _is_cuda and hidden_states.shape[0] < 4
        if overlap:
            return self._forward_core_shared_routed_overlap(hidden_states)
        return self._forward_core_normal(hidden_states)

    def _forward_core_normal(self, hidden_states):
        """Sequential path; returns ``(shared_out, routed_out)``."""
        # router_scores: [num_tokens, num_experts]
        router_logits, _ = self.router(hidden_states)
        shared_out = self.shared_expert(hidden_states)
        selection = self.topk(hidden_states, router_logits)
        routed_out = self.experts(hidden_states, selection)
        return shared_out, routed_out

    def _forward_core_shared_routed_overlap(self, hidden_states):
        """Overlapped path; returns ``(shared_out, routed_out)``.

        The shared expert runs on the current stream while the router, top-k
        selection and routed experts run on a module-wide side stream.
        """
        side_stream = _get_or_create_alt_stream(self.device_module)
        main_stream = self.device_module.current_stream()

        # The side stream must not start before work queued so far (which
        # produced hidden_states) on the main stream.
        side_stream.wait_stream(main_stream)

        shared_out = self.shared_expert(hidden_states)

        with self.device_module.stream(side_stream):
            # router_scores: [num_tokens, num_experts]
            router_logits, _ = self.router(hidden_states)
            selection = self.topk(hidden_states, router_logits)
            routed_out = self.experts(hidden_states, selection)

        # Join: later main-stream work waits for the routed result.
        main_stream.wait_stream(side_stream)

        return shared_out, routed_out


# Module-wide side stream used by Llama4MoE's shared/routed overlap path.
# Created lazily on the first request and reused afterwards.
_alt_stream = None


def _get_or_create_alt_stream(device_module):
    """Return the module-wide alternate stream, creating it on first use.

    ``device_module`` (anything exposing a ``Stream()`` factory) is only
    consulted on the first call; later calls return the cached stream no
    matter which device module they pass.
    """
    global _alt_stream
    if _alt_stream is not None:
        return _alt_stream
    _alt_stream = device_module.Stream()
    return _alt_stream


class Llama4Attention(nn.Module):
    """Self-attention for one Llama 4 decoder layer.

    Every fourth layer (``(layer_id + 1) % 4 == 0``) is a NoPE layer: it has
    no rotary embedding and no QK-norm. When
    ``config.attn_temperature_tuning`` is set, a NoPE layer instead scales its
    queries by a position-dependent temperature. All other layers apply RoPE
    and, if ``config.use_qk_norm``, a per-head RMSNorm on q and k.

    Heads are partitioned over the attention tensor-parallel group. The
    output projection is built with ``reduce_results=False``, so its result is
    not reduced across ranks inside this module.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        # Every 4th layer is a NoPE layer (no rotary embedding).
        self.use_rope = (layer_id + 1) % 4 != 0
        # QK-norm is only ever enabled on RoPE layers.
        self.use_qk_norm = config.use_qk_norm and self.use_rope

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % attn_tp_size == 0
        self.num_heads = self.total_num_heads // attn_tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= attn_tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % attn_tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert attn_tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.floor_scale = config.floor_scale
        self.attn_scale = config.attn_scale
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.n_rep = self.num_heads // self.num_kv_heads
        # Per-head norm over head_dim, shared by q and k.
        self.qk_norm = (
            RMSNorm(
                hidden_size=self.head_dim,
                eps=config.rms_norm_eps,
            )
            if self.use_qk_norm
            else None
        )

        # Honor the quantization "ignore" list. A listed q_proj disables
        # quantization for the whole fused qkv_proj; a listed o_proj disables
        # it for the output projection.
        qkv_quant_config = quant_config
        o_quant_config = quant_config
        if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
            if add_prefix("q_proj", prefix) in quant_config.ignore:
                qkv_quant_config = None
            if add_prefix("o_proj", prefix) in quant_config.ignore:
                o_quant_config = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=qkv_quant_config,
            prefix=add_prefix("qkv_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=o_quant_config,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
            reduce_results=False,
        )
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type in ["llama", "llama4"]:
            is_neox_style = False

        # NOTE(review): rope_scaling is annotated as a dict (or None); the
        # comparison with the string "default" only has an effect if a string
        # is passed. Confirm whether a check on the dict's rope type was meant.
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=max_position_embeddings,
                base=int(rope_theta),
                rope_scaling=rope_scaling if rope_scaling != "default" else None,
                is_neox_style=is_neox_style,
            )
            if self.use_rope
            else None
        )

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
            use_irope=self.use_rope,
        )

    def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Return the NoPE query temperature for each position.

        The value is ``log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1``,
        with a trailing unit dimension added so that it broadcasts over q's
        last dimension.
        """
        floor = torch.floor((positions + 1.0) / self.floor_scale)
        attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
        return attn_scale.unsqueeze(-1)

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def _mul_attn_scale(self, positions, q):
        """Scale ``q`` by the temperature and cast the result back to q's dtype."""
        attn_scale = self._get_attn_scale(positions)
        return (q * attn_scale).to(q.dtype)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Compute attention for this layer.

        Args:
            positions: Token positions. They are used for RoPE, and on NoPE
                layers with temperature tuning they set the query scale.
            hidden_states: Layer input activations.
            forward_batch: Batch metadata passed through to RadixAttention.

        Returns:
            The ``o_proj`` output (not reduced across attention-TP ranks).
        """
        qkv, _ = self.qkv_proj(hidden_states)

        # Keep q and k together in ``qk`` so RoPE and QK-norm can treat both
        # at once.
        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)

        if self.rotary_emb is not None:
            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
            # NOTE(review): the return values are discarded, so this relies on
            # rotary_emb rotating q_view/k_view (and therefore qk) in place. A
            # rotary implementation that returns new tensors instead would
            # silently drop RoPE here; confirm for every enabled backend.
            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
            del q_view, k_view, q_out_unused, k_out_unused

        if self.qk_norm is not None:
            # Normalize each head of q and k independently by viewing the
            # fused buffer as rows of length head_dim.
            # TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
            # Preserve the activation dtype instead of forcing bfloat16, so q/k
            # keep the same dtype as v. With the old hard-coded bf16 casts,
            # fp16 models produced bf16 q/k next to an fp16 v.
            orig_dtype = qk.dtype
            qk = qk.reshape(-1, self.head_dim).contiguous()
            qk = self.qk_norm(qk).to(orig_dtype)
            qk = qk.reshape(-1, self.q_size + self.kv_size)

        q, k = qk.split([self.q_size, self.kv_size], dim=-1)

        # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
        # the inference-time temperature tuning function is customized to not affect short context
        # while working at very long context
        # https://arxiv.org/abs/2501.19399
        if self.attn_temperature_tuning and not self.use_rope:
            q = self._mul_attn_scale(positions=positions, q=q)

        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class Llama4DecoderLayer(nn.Module):
    """One Llama 4 transformer block: self-attention followed by a feed-forward.

    On MoE layers (see ``_is_moe_layer``) the feed-forward is a ``Llama4MoE``;
    otherwise it is a dense ``LlamaMLP``. Layer norms, residual bookkeeping
    and the data movement around attention and MLP go through a
    ``LayerCommunicator``.
    """

    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.local_dp_size = get_local_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        self.self_attn = Llama4Attention(
            config=config,
            layer_id=layer_id,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )

        # _is_moe_layer reads self.config, so it has to be stored first.
        self.config = config
        is_moe_layer = self._is_moe_layer(layer_id)
        # For layer 0 this asks about the non-existent layer -1, and the
        # answer is True whenever interleave_moe_layer_step > 0.
        # NOTE(review): presumably LayerScatterModes ignores this flag for
        # the first layer; confirm.
        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)

        ffn_prefix = add_prefix("feed_forward", prefix)
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                layer_id=layer_id,
                quant_config=quant_config,
                prefix=ffn_prefix,
            )
        else:
            self.feed_forward = LlamaMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size_mlp,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=ffn_prefix,
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=is_moe_layer,
            is_previous_layer_sparse=is_previous_moe_layer,
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
        )

    def _is_moe_layer(self, layer_id: int) -> bool:
        """Return whether ``layer_id`` uses the MoE feed-forward.

        With ``interleave_moe_layer_step == 0``, every layer is MoE iff the
        config has local experts. Otherwise every ``step``-th layer (1-based)
        is MoE.
        """
        step = self.config.interleave_moe_layer_step
        if step == 0:
            return self.config.num_local_experts > 0
        return (layer_id + 1) % step == 0

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the feed-forward through the layer communicator.

        Returns:
            ``(hidden_states, residual)`` as produced by the communicator's
            ``postprocess_layer``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        # Attention is skipped entirely when no tokens remain after prepare_attn.
        has_tokens = hidden_states.shape[0] != 0
        if has_tokens:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )

        # Fully Connected
        hidden_states = self.feed_forward(hidden_states, forward_batch, reduce_scatter)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual


class Llama4Model(nn.Module):
    """Llama 4 text backbone: token embedding, decoder stack and final norm."""

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
            # TP sharding of the embedding is disabled when DP attention is on.
            enable_tp=not is_dp_attention_enabled(),
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Llama4DecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            prefix=add_prefix("layers", prefix),
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Indices of layers whose input hidden state (hidden + residual) is
        # also returned as an auxiliary output; empty by default.
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; ignored when ``input_embeds`` is given.
            positions: Token positions forwarded to every layer.
            forward_batch: Batch metadata forwarded to every layer.
            input_embeds: Optional precomputed embeddings that replace
                ``embed_tokens(input_ids)``.
            pp_proxy_tensors: Accepted for interface compatibility; not used
                by this implementation.

        Returns:
            The final hidden states. If any layer index is listed in
            ``layers_to_capture``, a ``(hidden_states, aux_hidden_states)``
            tuple is returned instead.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.layers_to_capture:
                # Before the first layer there is no residual stream yet;
                # adding None would raise a TypeError.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.norm(hidden_states, residual)

        if not aux_hidden_states:
            return hidden_states

        return hidden_states, aux_hidden_states


class Llama4ForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper for Llama 4 text models.

    All behavior comes from ``LlamaForCausalLM`` except ``_init_model``, which
    returns a ``Llama4Model``. Presumably this is the hook the parent
    constructor uses to build ``self.model``; confirm in llama.py.
    """

    # Packed (fused) module name -> the per-projection module names it combines.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Construction is delegated entirely to LlamaForCausalLM.
        super().__init__(config, quant_config, prefix)

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        backbone = self.model
        return backbone.embed_tokens

    def _init_model(
        self,
        config: Llama4TextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build and return the Llama 4 text backbone."""
        backbone = Llama4Model(config, quant_config=quant_config, prefix=prefix)
        return backbone


# Model class(es) exported by this module. Presumably this is discovered by
# SGLang's model registry; confirm against the model loader.
EntryClass = [
    Llama4ForCausalLM,
]