llama.py 22.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
20
from typing import Any, Dict, Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22
23
24
25

import torch
from torch import nn
from transformers import LlamaConfig

26
27
28
29
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
30
31
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
32
33
34
35
36
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
37
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
38
from sglang.srt.layers.pooler import Pooler, PoolingType
39
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
40
from sglang.srt.layers.radix_attention import RadixAttention
41
from sglang.srt.layers.rotary_embedding import get_rope
42
43
44
45
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
46
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
47
48
49
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
50
    maybe_remap_kv_scale_name,
51
)
52
from sglang.srt.utils import add_prefix, make_layers
53
from sglang.utils import get_exception_traceback
Liangsheng Yin's avatar
Liangsheng Yin committed
54

55
56
# Module-level logger named after this module; used below for load-time
# warnings and for error reports from get_weights_by_name.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
57
58
59
60
61
62
63

class LlamaMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, activation, down projection.

    Only the ``"silu"`` activation is accepted; any other ``hidden_act``
    raises ``ValueError`` during construction.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # The gate and up projections share one linear layer with two output
        # partitions of ``intermediate_size`` each.
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )

        # The activation is validated after the projections are built, as in
        # the original construction order.
        if hidden_act != "silu":
            message = (
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
            raise ValueError(message)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, then the down projection.

        Both linear layers return a pair; only the first element is used.
        """
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding, RadixAttention, output projection.

    Head counts are divided by the tensor-parallel world size. When there are
    fewer KV heads than tensor-parallel ranks, each rank keeps one KV head.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
    ) -> None:
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size

        # Query heads must divide evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo;
        # otherwise fall back to hidden_size / total heads.
        fallback_head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = getattr(config, "head_dim", fallback_head_dim)

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the last dimension into the q, k and v slices.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, forward_batch)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm, self-attention, RMSNorm, MLP.

    Rotary-embedding and bias settings are read from ``config`` with
    defaults when the attribute is absent.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_positions = getattr(
            config, "original_max_position_embeddings", None
        )
        if rope_scaling is not None and original_max_positions:
            # Written into the config's own rope_scaling dict (in place).
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False)
        if not attention_bias:
            attention_bias = getattr(config, "bias", False)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
            bias=attention_bias,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming ``hidden_states`` becomes
        the residual; otherwise the input norm is called with both tensors
        and returns the updated pair.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual


class LlamaModel(nn.Module):
    """Token embedding, a stack of ``LlamaDecoderLayer`` modules, and a final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embed_tokens", prefix),
        )
        # NOTE(review): the layer prefix is the literal "model.layers" and does
        # not incorporate ``prefix`` -- confirm this is intentional.
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            prefix="model.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Passed through to every decoder layer.
            forward_batch: Passed through to every decoder layer.
            input_embeds: Optional precomputed embeddings used in place of
                ``embed_tokens(input_ids)``.

        Returns:
            The first element returned by the final norm.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # The first layer receives residual=None and creates the residual.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors and assign them to attention.

        Each ``(layer_idx, scaling_factor)`` yielded by
        ``kv_cache_scales_loader`` is written to both ``k_scale`` and
        ``v_scale`` of that layer's attention module.

        Raises:
            RuntimeError: If a layer's attention module has no ``k_scale``
                attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            layer = self.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # An Identity placeholder has no attention module. Skip it
                # rather than reuse (or fail on) the previous layer's
                # attention reference.
                continue
            layer_self_attn = layer.self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling " "factor attribute!"
                )

Lianmin Zheng's avatar
Lianmin Zheng committed
336
337

class LlamaForCausalLM(nn.Module):
    """LLaMA decoder with an LM head, logits processor and last-token pooler.

    Checkpoint weights named ``q_proj``/``k_proj``/``v_proj`` and
    ``gate_proj``/``up_proj`` are loaded into the fused ``qkv_proj`` and
    ``gate_up_proj`` parameters via ``stacked_params_mapping``.
    """

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # Llama 3.2 1B Instruct set tie_word_embeddings to True
        # Llama 3.1 8B Instruct set tie_word_embeddings to False
        if self.config.tie_word_embeddings:
            # Tied: the LM head is the very same module as the embedding.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Run the decoder and return logits, or pooled embeddings.

        When ``get_embedding`` is true the hidden states go through the
        pooler; otherwise they go through the logits processor with the LM head.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def get_hidden_dim(self, module_name):
        """Return ``(input_dim, output_dim)`` for a named projection module.

        Raises:
            NotImplementedError: For any module name not handled below.
        """
        cfg = self.config
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return cfg.hidden_size, cfg.hidden_size
        if module_name in ["kv_proj"]:
            # hidden_size divided by the query-heads-per-KV-head ratio.
            heads_per_kv = cfg.num_attention_heads // cfg.num_key_value_heads
            return cfg.hidden_size, cfg.hidden_size // heads_per_kv
        if module_name == "gate_up_proj":
            return cfg.hidden_size, cfg.intermediate_size
        if module_name == "down_proj":
            return cfg.intermediate_size, cfg.hidden_size
        raise NotImplementedError()

    def get_module_name(self, name):
        """Map a checkpoint projection name to its fused module name.

        Names that are not part of a fused module are returned unchanged.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        """Return ``(module_name, num_shard)`` for a checkpoint weight name.

        The trailing ``len(".weight")`` characters are stripped from the name
        unconditionally. For a weight that belongs to a fused parameter, the
        fused name is substituted and ``num_shard`` is the number of entries in
        ``stacked_params_mapping`` sharing that fused parameter. Otherwise
        ``num_shard`` is 1.
        """
        suffix_len = len(".weight")
        mapping = self.stacked_params_mapping
        for entry in mapping:
            # Index rather than unpack: entries are (param, shard_name,
            # shard_id) and carry no explicit shard count.
            param_name, weight_name = entry[0], entry[1]
            if weight_name in name:
                num_shard = sum(1 for other in mapping if other[0] == param_name)
                return (
                    name.replace(weight_name, param_name)[:-suffix_len],
                    num_shard,
                )
        return name[:-suffix_len], 1

    def get_num_params(self):
        """Return the number of named parameters (not the number of elements)."""
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.

        Weights with no matching parameter are skipped for a set of known
        cases; any other unmatched, non-stacked name is logged as a warning.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # The head shares the embedding weight; nothing to load.
                continue
            # Handle FP8 kv-scale remapping
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not loaded as a stacked shard: handle as a plain parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")

    def _stacked_shard_bounds(self, shard_id, tp_size: int):
        """Return ``(offset, size)`` of a shard along dim 0 of its fused parameter.

        Returns ``None`` when ``shard_id`` is neither a q/k/v tag nor 0/1.
        """
        cfg = self.config
        if shard_id in ("q", "k", "v"):
            num_heads = cfg.num_attention_heads // tp_size
            # Same per-rank head count and head_dim rules as LlamaAttention:
            # at least one KV head per rank, and an explicit config.head_dim
            # takes precedence over hidden_size / num_attention_heads.
            num_kv_heads = max(1, cfg.num_key_value_heads // tp_size)
            head_dim = getattr(
                cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads
            )
            q_size = num_heads * head_dim
            kv_size = num_kv_heads * head_dim
            if shard_id == "q":
                return 0, q_size
            if shard_id == "k":
                return q_size, kv_size
            return q_size + kv_size, kv_size
        if shard_id in (0, 1):
            # 0 is gate_proj, 1 is up_proj; each takes one equal slice.
            slice_size = cfg.intermediate_size // tp_size
            return shard_id * slice_size, slice_size
        return None

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[list]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Returns:
            The first ``truncate_size`` rows of the weight as nested Python
            lists of floats, or ``None`` if any error occurs (the error is
            logged).
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.embed_tokens.weight.detach()
                    .cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            # Translate a per-projection checkpoint name to the fused
            # parameter that holds it, remembering which shard was requested.
            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]

            weight = param.data
            if mapped_shard_id is not None:
                bounds = self._stacked_shard_bounds(mapped_shard_id, tp_size)
                if bounds is not None:
                    offset, size = bounds
                    weight = weight.narrow(0, offset, size)

            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                # These weights are gathered from all ranks and joined on dim 1.
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            # Deliberate best-effort: report and signal failure with None.
            logger.error(
                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
            )
            return None

    def get_embed_and_head(self):
        """Return ``(embed_tokens.weight, lm_head.weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        With tied embeddings the LM head is the embedding module itself, so
        its weight is deleted only once and ``head`` (assigned last) is the
        weight that remains.
        """
        tied = self.lm_head is self.model.embed_tokens
        del self.model.embed_tokens.weight
        if not tied:
            del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        # Release the old weights' cached memory; skipped where CUDA is
        # unavailable so the call does not fail on such hosts.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Delegate KV cache scale loading to the inner ``LlamaModel``."""
        self.model.load_kv_cache_scales(quantization_param_path)

Lianmin Zheng's avatar
Lianmin Zheng committed
592

593
594
595
596
class Phi3ForCausalLM(LlamaForCausalLM):
    """Phi3 entry point; inherits ``LlamaForCausalLM`` without changes."""


597
598
599
600
601
class InternLM3ForCausalLM(LlamaForCausalLM):
    """InternLM3 entry point; inherits ``LlamaForCausalLM`` without changes."""


# Model classes exported by this module.
# NOTE(review): presumably consumed by the model loader/registry to discover
# architectures -- confirm against the loader.
EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]