llama.py 15.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
# Adapted from
17
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
Lianmin Zheng's avatar
Lianmin Zheng committed
18
"""Inference-only LLaMA model compatible with HuggingFace weights."""
19

20
from typing import Any, Dict, Iterable, Optional, Tuple
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22
23
24

import torch
from torch import nn
from transformers import LlamaConfig
25
from vllm.distributed import get_tensor_model_parallel_world_size
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27
28
29
30
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Lianmin Zheng's avatar
Lianmin Zheng committed
32

33
34
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
35
36
37
38
39
from sglang.srt.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
40
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
41
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Liangsheng Yin's avatar
Liangsheng Yin committed
42
from sglang.srt.layers.radix_attention import RadixAttention
43
from sglang.srt.layers.torchao_utils import apply_torchao_config_
44
from sglang.srt.managers.schedule_batch import global_server_args_dict
45
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Liangsheng Yin's avatar
Liangsheng Yin committed
46

Lianmin Zheng's avatar
Lianmin Zheng committed
47
48
49
50
51
52
53

class LlamaMLP(nn.Module):
    """LLaMA feed-forward block.

    A fused gate/up column-parallel projection produces two halves of width
    ``intermediate_size``; ``SiluAndMul`` combines them and a row-parallel
    down projection maps the result back to ``hidden_size``. Only the
    ``"silu"`` activation is accepted; anything else raises ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one matmul: one output slice each.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        """Project up (gate+up), apply ``SiluAndMul``, then project down."""
        fused, _bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _bias = self.down_proj(activated)
        return out


class LlamaAttention(nn.Module):
    """Tensor-parallel self-attention used inside a LLaMA decoder layer.

    Query heads are divided evenly across tensor-parallel ranks. KV heads are
    partitioned across ranks when there are at least as many KV heads as
    ranks, and replicated otherwise, so every rank holds at least one KV
    head. Rotary position embeddings are applied to q/k before the
    ``RadixAttention`` call.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: each KV head is replicated on
            # several ranks, so the rank count must be a multiple.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: KV heads are partitioned.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo;
        # without it the head size is derived from hidden_size.
        default_head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = getattr(config, "head_dim", default_head_dim)

        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        attn_out_width = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            attn_out_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run fused QKV projection, RoPE, attention and output projection."""
        fused_qkv, _bias = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, forward_batch)
        projected, _bias = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer layer: normed self-attention then a normed MLP.

    The residual stream is threaded through the layer norms: the caller
    passes the residual returned by the previous layer (or ``None`` for the
    first layer) and receives ``(hidden_states, residual)`` back. When a
    residual is supplied, ``RMSNorm`` is called with it and returns an
    updated ``(normalized, residual)`` pair (presumably a fused add + norm;
    confirm against ``sglang.srt.layers.layernorm.RMSNorm``).
    """

    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            # Add the key to a shallow copy: ``config.rope_scaling`` is owned
            # by the caller and shared by every layer, so it must not be
            # mutated as a side effect of building the model.
            rope_scaling = dict(rope_scaling)
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP; return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer, in which case the raw
        input becomes the residual before normalization.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers and a final RMSNorm.

    Layer ``i`` is built with the weight-name prefix ``model.layers.{i}``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(
                LlamaDecoderLayer(
                    config,
                    layer_idx,
                    quant_config=quant_config,
                    prefix=f"model.layers.{layer_idx}",
                )
            )
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run every decoder layer and return the final normalized states.

        When ``input_embeds`` is given it is used directly and ``input_ids``
        is not embedded.
        """
        if input_embeds is not None:
            hidden_states = input_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)

        # The first layer receives no residual; each layer hands its residual
        # to the next one.
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        normed, _ = self.norm(hidden_states, residual)
        return normed


class LlamaForCausalLM(nn.Module):
    """LLaMA causal language model: ``LlamaModel`` plus a parallel LM head.

    Besides ``forward`` and ``load_weights`` it exposes small helpers that
    describe the fused projections (``qkv_proj``, ``gate_up_proj``) to
    external callers: their dimensions and how checkpoint weight names map
    onto them.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ) -> None:
        # ``cache_config`` is accepted for signature compatibility; it is not
        # used in this class.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        # Server-wide torchao setting; presumably consumed by
        # ``apply_torchao_config_(self, ...)`` in ``load_weights`` — confirm.
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = LlamaModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> LogitsProcessorOutput:
        """Run the backbone and turn its hidden states into logits output."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

    def get_hidden_dim(self, module_name):
        """Return ``(input_dim, output_dim)`` of a projection, from the config.

        Sizes are the full (unsharded) model sizes read from ``self.config``.
        ``qkv_proj`` is reported like ``q_proj`` (its query slice) and
        ``kv_proj`` describes a single K or V slice. The per-head size is
        resolved exactly as ``LlamaAttention`` does (``config.head_dim`` if
        present, else ``hidden_size // num_attention_heads``), so configs
        where ``num_heads * head_dim != hidden_size`` (e.g. Mistral-Nemo) get
        correct attention dimensions.

        Raises:
            NotImplementedError: for any other module name.
        """
        config = self.config
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        q_dim = config.num_attention_heads * head_dim
        if module_name in ["q_proj", "qkv_proj"]:
            return config.hidden_size, q_dim
        elif module_name == "o_proj":
            return q_dim, config.hidden_size
        elif module_name == "kv_proj":
            return config.hidden_size, config.num_key_value_heads * head_dim
        elif module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
        elif module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        else:
            raise NotImplementedError()

    def get_module_name(self, name):
        """Map a checkpoint projection name to the fused module holding it.

        Names that are not part of a fused projection are returned unchanged.
        """
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        """Return ``(module_name, num_shards)`` for a checkpoint weight name.

        The trailing ``.weight`` is stripped (the name is assumed to end with
        it). Shard names are rewritten to their fused module, with the number
        of shards that module stacks; other weights report 1 shard.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id, num_shard)
            ("qkv_proj", "q_proj", "q", 3),
            ("qkv_proj", "k_proj", "k", 3),
            ("qkv_proj", "v_proj", "v", 3),
            ("gate_up_proj", "gate_proj", 0, 2),
            ("gate_up_proj", "up_proj", 1, 2),
        ]
        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        """Return the number of named parameter tensors (not element count)."""
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs from a HuggingFace-style checkpoint.

        Separate q/k/v and gate/up checkpoint tensors are routed into their
        fused parameters via the parameter's ``weight_loader`` with a shard
        id. Rotary caches, projector weights, vision-tower weights this model
        does not own, extra GPTQ biases and legacy ``kv_scale`` entries are
        skipped. Afterwards the LM head is tied to the embedding when the
        config requests it, and ``apply_torchao_config_`` is applied.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. This ``continue``
                # only advances the inner loop; the rewritten name matches no
                # further shard name, so the loop ends without ``break`` and
                # the ``else`` branch below skips the tensor.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        if getattr(self.config, "tie_word_embeddings", False):
            # Tie output embedding layer to input embedding layer, to solve
            # issues where lm_head.weight is missing from the checkpoint.
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, self.model.embed_tokens.weight)
        apply_torchao_config_(self, params_dict, {"proj.weight"})
418

Lianmin Zheng's avatar
Lianmin Zheng committed
419

420
421
422
423
424
class Phi3ForCausalLM(LlamaForCausalLM):
    """Phi-3 architecture alias: inherits ``LlamaForCausalLM`` with no overrides."""


# Model classes exported by this module (presumably picked up by the model
# registry via this name — confirm against the loader).
EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]