"cacheflow/vscode:/vscode.git/clone" did not exist on "64e0e383148a613c327d4bf9e866b7a185df8277"
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers

logger = logging.getLogger(__name__)


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
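        # e.g. 8 total KV heads with tp_size=4 gives 2 KV heads per rank, while
        # 2 total KV heads with tp_size=4 replicates each KV head on 2 ranks.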
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
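        # Standard scaled dot-product attention factor, 1 / sqrt(head_dim).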
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
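        # RMSNorm fuses the residual add: norm(x, residual) returns
        # (normalized, x + residual) for the next sub-layer to consume.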
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            prefix="model.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = LlamaModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
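        # e.g. "...self_attn.q_proj.weight" in a checkpoint maps onto the "q"
        # shard of the fused "...self_attn.qkv_proj" parameter.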

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head.weight, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def get_hidden_dim(self, module_name):
        # Returns (input_dim, output_dim) for the named projection module.
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return self.config.hidden_size, self.config.hidden_size
        elif module_name in ["kv_proj"]:
            return self.config.hidden_size, self.config.hidden_size // (
                self.config.num_attention_heads // self.config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return self.config.hidden_size, self.config.intermediate_size
        elif module_name == "down_proj":
            return self.config.intermediate_size, self.config.hidden_size
        else:
            raise NotImplementedError()

    def get_module_name(self, name):
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        for param_name, weight_name, shard_id in self.stacked_params_mapping:
            if weight_name in name:
                # qkv_proj packs three shards (q, k, v); gate_up_proj packs two.
                num_shard = 3 if param_name == ".qkv_proj" else 2
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        embed_tokens_weight = None
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

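        # Some checkpoints tie lm_head to embed_tokens and omit lm_head.weight;
        # remember the embedding weight so lm_head can be tied after the loop.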
        load_tie_word_embeddings = (
            hasattr(self.config, "tie_word_embeddings")
            and self.config.tie_word_embeddings
            and "lm_head.weight" in params_dict
        )

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

                if load_tie_word_embeddings and name == "model.embed_tokens.weight":
                    embed_tokens_weight = loaded_weight

        if load_tie_word_embeddings:
            # Tie the output embedding (lm_head) to the input embedding so that
            # checkpoints which omit lm_head.weight still load correctly.
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if embed_tokens_weight is not None:
                weight_loader(param, embed_tokens_weight)

        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            if mapped_name in params_dict:
                param = params_dict[mapped_name]
                if mapped_shard_id is not None:
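                    # The fused qkv_proj weight is laid out as [Q | K | V] along
                    # dim 0, so a shard is recovered by narrowing at its offset.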
                    if mapped_shard_id in ["q", "k", "v"]:
                        num_heads = self.config.num_attention_heads // tp_size
                        num_kv_heads = self.config.num_key_value_heads // tp_size
                        head_dim = (
                            self.config.hidden_size // self.config.num_attention_heads
                        )
                        if mapped_shard_id == "q":
                            offset = 0
                            size = num_heads * head_dim
                        elif mapped_shard_id == "k":
                            offset = num_heads * head_dim
                            size = num_kv_heads * head_dim
                        elif mapped_shard_id == "v":
                            offset = (num_heads + num_kv_heads) * head_dim
                            size = num_kv_heads * head_dim
                        weight = param.data.narrow(0, offset, size)
                    elif mapped_shard_id in [0, 1]:
                        intermediate_size = self.config.intermediate_size
                        hidden_size = self.config.hidden_size
                        slice_size = intermediate_size // tp_size
                        if mapped_shard_id == 0:  # gate_proj
                            offset = 0
                            size = slice_size
                        elif mapped_shard_id == 1:  # up_proj
                            offset = slice_size
                            size = slice_size

                        weight = param.data.narrow(0, offset, size)
                    else:
                        weight = param.data
                else:
                    weight = param.data
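                # o_proj and down_proj are row-parallel (sharded along the input
                # dim), so gather the per-rank shards and concatenate on dim 1.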
                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                    gathered_weights = [
                        torch.zeros_like(weight) for _ in range(tp_size)
                    ]
                    torch.distributed.all_gather(gathered_weights, weight)
                    weight = torch.cat(gathered_weights, dim=1)
                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
            else:
                return None

        except Exception as e:
            logger.error(
                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
            )
            return None


class Phi3ForCausalLM(LlamaForCausalLM):
    pass


EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]