# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, forward_batch=None):
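        # gate_up_proj packs the gate and up projections into a single GEMM;
        # SiluAndMul then computes silu(gate) * up before the down projection.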
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
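        # Project to the fused [q | k | v] tensor, apply rotary embeddings, then
        # let RadixAttention handle KV-cache access and attention for this batch.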
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            bias=attention_bias,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
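        # Both RMSNorm calls fuse the residual addition: when a residual tensor
        # is passed in, they return (normalized_hidden_states, updated_residual).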
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.layers, self.start_layer, self.end_layer = make_layers(
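            # Pipeline parallelism: only the layers assigned to this PP rank are
            # instantiated here; layers owned by other ranks become placeholders.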
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="model.layers",
        )

        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)
        self.layers_to_capture = []

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]
            deferred_norm = None

        aux_hidden_states = []
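        # For layers registered in layers_to_capture (set via
        # set_eagle3_layers_to_capture), record the layer input
        # (hidden_states + residual) as an auxiliary hidden state for EAGLE3.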
        for i in range(self.start_layer, self.end_layer):
            if i in self.layers_to_capture:
                aux_hidden_states.append(hidden_states + residual)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )

        if not self.pp_group.is_last_rank:
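            # Intermediate pipeline ranks hand hidden_states and residual to the
            # next stage as proxy tensors; only the last rank applies the final norm.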
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn

                if hasattr(layer_self_attn.attn, "k_scale"):
                    layer_self_attn.attn.k_scale = scaling_factor
                    layer_self_attn.attn.v_scale = scaling_factor
                else:
                    raise RuntimeError(
                        "Self attention has no KV cache scaling factor attribute!"
                    )


class LlamaForCausalLM(nn.Module):
    # BitsAndBytes-specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        ".q_proj": (".qkv_proj", 0),
        ".k_proj": (".qkv_proj", 1),
        ".v_proj": (".qkv_proj", 2),
        ".gate_proj": (".gate_up_proj", 0),
        ".up_proj": (".gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
        # Llama 3.2 1B Instruct sets tie_word_embeddings to True
        # Llama 3.1 8B Instruct sets tie_word_embeddings to False
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        self.capture_aux_hidden_states = False

    def _init_model(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        return LlamaModel(config, quant_config=quant_config, prefix=prefix)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> LogitsProcessorOutput:
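        # Run the transformer backbone; on the last pipeline rank the hidden
        # states are turned into logits (or pooled embeddings), otherwise the
        # proxy tensors are simply forwarded to the next stage.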
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids,
                    hidden_states,
                    self.lm_head,
                    forward_batch,
                    aux_hidden_states,
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @torch.no_grad()
    def forward_split_prefill(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        split_interval: Tuple[int, int],  # [start, end) 0-based
        input_embeds: torch.Tensor = None,
    ) -> Optional[LogitsProcessorOutput]:
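        # Chunked prefill: only layers in [start, end) run in this call, with the
        # running hidden_states / residual stashed on forward_batch between calls.
        # Logits are produced only by the call that completes the final layer.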
        start, end = split_interval
        # embed
        if start == 0:
            if input_embeds is None:
                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
            else:
                forward_batch.hidden_states = input_embeds
        # decoder layer
        for i in range(start, end):
            layer = self.model.layers[i]
            forward_batch.hidden_states, forward_batch.residual = layer(
                positions,
                forward_batch.hidden_states,
                forward_batch,
                forward_batch.residual,
            )

        if end == self.model.config.num_hidden_layers:
            # norm
            hidden_states, _ = self.model.norm(
                forward_batch.hidden_states, forward_batch.residual
            )
            forward_batch.hidden_states = hidden_states
            # logits process
            result = self.logits_processor(
                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
            )
        else:
            result = None

        return result

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_module_name_from_weight_name(self, name):
        for param_name, weight_name, shard_id in self.stacked_params_mapping:
            if weight_name in name:
                # qkv_proj stacks 3 shards (q, k, v); gate_up_proj stacks 2
                # (gate, up).
                num_shard = 3 if param_name == ".qkv_proj" else 2
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            # Handle FP8 kv-scale remapping
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
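                # No stacked-parameter mapping matched, so load this weight
                # directly with its parameter's weight_loader (or the default).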
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit tests; performance is not optimized.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.embed_tokens.weight.cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]
            if mapped_shard_id is not None:
                if mapped_shard_id in ["q", "k", "v"]:
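                    # qkv_proj stores [q | k | v] concatenated along dim 0, so the
                    # requested shard is a narrow() at the matching row offset.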
                    num_heads = self.config.num_attention_heads // tp_size
                    num_kv_heads = self.config.num_key_value_heads // tp_size
                    head_dim = (
                        self.config.hidden_size // self.config.num_attention_heads
                    )
                    if mapped_shard_id == "q":
                        offset = 0
                        size = num_heads * head_dim
                    elif mapped_shard_id == "k":
                        offset = num_heads * head_dim
                        size = num_kv_heads * head_dim
                    elif mapped_shard_id == "v":
                        offset = (num_heads + num_kv_heads) * head_dim
                        size = num_kv_heads * head_dim
                    weight = param.data.narrow(0, offset, size)
                elif mapped_shard_id in [0, 1]:
                    intermediate_size = self.config.intermediate_size
                    slice_size = intermediate_size // tp_size
                    if mapped_shard_id == 0:  # gate_proj
                        offset = 0
                        size = slice_size
                    elif mapped_shard_id == 1:  # up_proj
                        offset = slice_size
                        size = slice_size

                    weight = param.data.narrow(0, offset, size)
                else:
                    weight = param.data
            else:
                weight = param.data
            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            logger.error(
                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
            )
            return None

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        return self.model.embed_tokens.weight

    def set_embed(self, embed):
        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
        if (
            hasattr(self.config, "target_hidden_size")
            and self.config.target_hidden_size != self.config.hidden_size
        ):
            return
        del self.model.embed_tokens.weight
        self.model.embed_tokens.weight = embed
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
        if not self.pp_group.is_last_rank:
            return

        if layer_ids is None:
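            # Default EAGLE3 capture points: one early, one middle, and one late
            # decoder layer.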
            self.capture_aux_hidden_states = True
            num_layers = self.config.num_hidden_layers
            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
        else:
            self.capture_aux_hidden_states = True
            # Add 1 because in sglang the ith layer takes the output of the
            # (i-1)th layer as its aux hidden state.
            self.model.layers_to_capture = [val + 1 for val in layer_ids]


class Phi3ForCausalLM(LlamaForCausalLM):
    pass


class InternLM3ForCausalLM(LlamaForCausalLM):
    pass


EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]