llama.py 21.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Lianmin Zheng's avatar
Lianmin Zheng committed
15
# Adapted from
16
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
Lianmin Zheng's avatar
Lianmin Zheng committed
17
"""Inference-only LLaMA model compatible with HuggingFace weights."""
18

19
import logging
20
from typing import Any, Dict, Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22
23
24

import torch
from torch import nn
from transformers import LlamaConfig
bjmsong's avatar
bjmsong committed
25
26
27
28
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
29
from vllm.model_executor.layers.rotary_embedding import get_rope
bjmsong's avatar
bjmsong committed
30
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
Lianmin Zheng's avatar
Lianmin Zheng committed
31

32
33
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
34
35
36
37
38
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
39
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
40
from sglang.srt.layers.pooler import Pooler, PoolingType
41
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
42
from sglang.srt.layers.radix_attention import RadixAttention
43
44
45
46
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
47
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
48
from sglang.srt.model_loader.weight_utils import default_weight_loader
49
from sglang.srt.utils import make_layers
50
from sglang.utils import get_exception_traceback
Liangsheng Yin's avatar
Liangsheng Yin committed
51

52
53
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
54
55
56
57
58
59
60

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer producing two ``intermediate_size`` halves. ``SiluAndMul``
    combines them, and a row-parallel ``down_proj`` maps the result back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equally sized output slices: [gate | up].
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # Validated after the projections are built, matching the original order.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        fused_gate_up, _bias = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _bias = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention sub-layer of a LLaMA decoder layer.

    A fused tensor-parallel QKV projection feeds rotary position embeddings
    (applied to Q and K). ``RadixAttention`` does the attention itself, and
    a row-parallel ``o_proj`` produces the output. Query heads are split
    evenly across tensor-parallel ranks. KV heads are split when there are
    at least as many as ranks and replicated otherwise.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_world = get_tensor_model_parallel_world_size()

        # Query heads: every rank gets an equal, disjoint share.
        self.total_num_heads = num_heads
        assert num_heads % tp_world == 0
        self.num_heads = num_heads // tp_world

        # KV heads: partitioned when num_kv_heads >= tp_world, otherwise
        # replicated so that each rank still holds at least one.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads >= tp_world:
            assert num_kv_heads % tp_world == 0
        else:
            assert tp_world % num_kv_heads == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_world)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        derived_head_dim = hidden_size // num_heads
        self.head_dim = getattr(config, "head_dim", derived_head_dim)

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        fused_qkv, _ = self.qkv_proj(hidden_states)
        slice_widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(slice_widths, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, forward_batch)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA transformer block: attention then gated MLP.

    Rope settings and the attention-bias flag are read from ``config`` with
    fallbacks. The residual stream is threaded through the RMSNorm calls:
    ``forward`` takes and returns it alongside the hidden states.
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max = getattr(config, "original_max_position_embeddings", None)
            if original_max:
                # NOTE(review): this writes into the config's own rope_scaling
                # dict (shared by every layer). That is harmless because the
                # write is idempotent, but it is a mutation of the config.
                rope_scaling["original_max_position_embeddings"] = original_max
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False
        )

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            bias=attention_bias,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Pre-attention norm. With no running residual yet (first layer), the
        # raw input becomes the residual. Otherwise the two-argument RMSNorm
        # form returns both the normed states and the updated residual.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Post-attention norm (also folds attn_out into the residual), then MLP.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class LlamaModel(nn.Module):
    """LLaMA transformer trunk: token embedding, decoder stack, final RMSNorm.

    ``forward`` returns the final normalized hidden states; the LM head lives
    in ``LlamaForCausalLM``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            prefix="model.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the decoder stack.

        If ``input_embeds`` is given it is used instead of looking up
        ``input_ids`` in the embedding table.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Each layer returns (hidden_states, residual); the residual starts as
        # None and is created by the first layer.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV-cache scaling factors from ``quantization_param_path``.

        The same factor is written to both ``k_scale`` and ``v_scale`` of
        each layer's attention. Layers that are ``nn.Identity`` placeholders
        have no attention module and are skipped.

        Raises:
            RuntimeError: if a real layer's attention has no ``k_scale``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            layer = self.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Presumably a placeholder for a layer not materialized here.
                # Without this skip, the previous iteration's attention module
                # (or an unbound name, for the first layer) would be used and
                # receive the wrong scaling factor.
                continue
            layer_self_attn = layer.self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling " "factor attribute!"
                )

Lianmin Zheng's avatar
Lianmin Zheng committed
330
331

class LlamaForCausalLM(nn.Module):
    """LLaMA causal LM: ``LlamaModel`` trunk plus an LM head (optionally tied).

    Also owns checkpoint loading, which maps separate HF ``q/k/v_proj`` and
    ``gate/up_proj`` tensors onto the fused ``qkv_proj`` / ``gate_up_proj``
    parameters. It provides a few weight/module-name helpers as well.
    """

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config=quant_config)
        # Llama 3.2 1B Instruct set tie_word_embeddings to True
        # Llama 3.1 8B Instruct set tie_word_embeddings to False
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size, config.hidden_size, quant_config=quant_config
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Return logits-processor output, or pooled embeddings if ``get_embedding``."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def get_hidden_dim(self, module_name):
        # return input_dim, output_dim
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return self.config.hidden_size, self.config.hidden_size
        elif module_name in ["kv_proj"]:
            return self.config.hidden_size, self.config.hidden_size // (
                self.config.num_attention_heads // self.config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return self.config.hidden_size, self.config.intermediate_size
        elif module_name == "down_proj":
            return self.config.intermediate_size, self.config.hidden_size
        else:
            raise NotImplementedError()

    def get_module_name(self, name):
        """Map an unfused projection name to the fused module that holds it."""
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        """Map a checkpoint weight name to ``(fused_module_name, num_shard)``.

        ``num_shard`` is how many checkpoint tensors are stacked into the
        fused module (3 for ``qkv_proj``, 2 for ``gate_up_proj``). Names that
        are not stacked map to themselves with ``num_shard == 1``. In both
        cases a trailing ``".weight"`` is stripped.
        """
        for entry in self.stacked_params_mapping:
            # Entries are (param_name, shard_name, shard_id[, num_shard]). The
            # mapping built in __init__ has only three fields, so the shard
            # count is derived by counting entries for the same fused param.
            param_name, weight_name = entry[0], entry[1]
            if weight_name not in name:
                continue
            if len(entry) > 3:
                num_shard = entry[3]
            else:
                num_shard = sum(
                    1
                    for other in self.stacked_params_mapping
                    if other[0] == param_name
                )
            return (
                name.replace(weight_name, param_name)[: -len(".weight")],
                num_shard,
            )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model's parameters.

        Unfused q/k/v and gate/up tensors are routed into the matching shard
        of the fused parameter via that parameter's ``weight_loader``. Every
        other tensor is loaded by name with its ``weight_loader`` (or
        ``default_weight_loader``). Rotary caches, projector weights, extra
        GPTQ biases and legacy ``kv_scale`` entries are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # (This `continue` leaves the inner loop without `break`, so
                # the `else` branch below re-checks and skips the same name.)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Args:
            name: Checkpoint-style parameter name. Unfused q/k/v and gate/up
                names are mapped onto their slice of the fused parameter.
            truncate_size: Number of leading rows kept in the result.
            tp_size: Tensor-parallel world size the model is sharded with.

        Returns:
            The truncated weight as a nested list of floats (despite the
            annotation), or ``None`` if the lookup failed; the error is logged.
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.embed_tokens.weight.cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]
            if mapped_shard_id is not None:
                if mapped_shard_id in ["q", "k", "v"]:
                    # Mirror LlamaAttention's per-rank layout of qkv_proj:
                    # KV heads are replicated (at least one per rank) when
                    # there are fewer of them than ranks, and an explicit
                    # config.head_dim (e.g. Mistral-Nemo) takes precedence.
                    num_heads = self.config.num_attention_heads // tp_size
                    num_kv_heads = max(1, self.config.num_key_value_heads // tp_size)
                    head_dim = getattr(
                        self.config,
                        "head_dim",
                        self.config.hidden_size // self.config.num_attention_heads,
                    )
                    if mapped_shard_id == "q":
                        offset = 0
                        size = num_heads * head_dim
                    elif mapped_shard_id == "k":
                        offset = num_heads * head_dim
                        size = num_kv_heads * head_dim
                    else:  # "v"
                        offset = (num_heads + num_kv_heads) * head_dim
                        size = num_kv_heads * head_dim
                    weight = param.data.narrow(0, offset, size)
                elif mapped_shard_id in [0, 1]:
                    # gate_up_proj rows are [gate | up], each of per-rank width.
                    slice_size = self.config.intermediate_size // tp_size
                    offset = 0 if mapped_shard_id == 0 else slice_size
                    weight = param.data.narrow(0, offset, slice_size)
                else:
                    weight = param.data
            else:
                weight = param.data
            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                # Row-parallel weights are split along dim 1; reassemble them.
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            logger.error(
                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
            )
            return None

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        # NOTE(review): when tie_word_embeddings is set, self.lm_head is the
        # same module as self.model.embed_tokens, so the second `del` below
        # appears to target an already-removed attribute. Confirm intended
        # behaviour for tied models before relying on this.
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

Lianmin Zheng's avatar
Lianmin Zheng committed
568

569
570
571
572
573
class Phi3ForCausalLM(LlamaForCausalLM):
    """Phi-3 served by the LLaMA implementation with no changes.

    Presumably exists so the ``Phi3ForCausalLM`` architecture name resolves
    to this code path.
    """


# Model classes exported by this module (presumably read by the model registry).
EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]