# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip


class LlamaMLP(nn.Module):
    """LLaMA feed-forward block.

    The input goes through a fused gate/up projection
    (``MergedColumnParallelLinear`` producing two ``intermediate_size``
    outputs), then ``SiluAndMul``, then a row-parallel down projection back
    to ``hidden_size``. None of the projections carries a bias.

    Args:
        hidden_size: Width of the model's hidden states.
        intermediate_size: Output width of each of the gate and up
            projections.
        hidden_act: Name of the activation; only ``"silu"`` is supported.
        linear_method: Optional linear/quantization backend forwarded to the
            parallel linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Validate before building the (potentially large) projection layers
        # so an unsupported config fails fast without allocating weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections share one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, ``SiluAndMul`` and down projection.

        Args:
            x: Input activations whose last dimension is ``hidden_size``.

        Returns:
            The down-projection output.
        """
        # The parallel linear layers return (output, bias); the bias slot is
        # unused here because the layers are built with bias=False.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention layer used by LLaMA (supports grouped KV heads).

    Query, key and value come from one fused ``QKVParallelLinear``. Rotary
    position embeddings are applied to q and k. Attention itself is
    delegated to vLLM's ``Attention`` layer, which is also handed the KV
    cache. The result is projected back with a row-parallel ``o_proj``.

    Heads are sharded across tensor-parallel (TP) ranks:
    - Query heads must divide evenly by the TP size.
    - KV heads are partitioned when there are at least as many as TP ranks.
    - Otherwise KV heads are replicated, so each rank keeps at least one.

    Args:
        hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
        num_heads: Total number of query heads across all TP ranks.
        num_kv_heads: Total number of key/value heads across all TP ranks.
        rope_theta: Base passed to ``get_rope``.
        rope_scaling: Optional RoPE-scaling config forwarded to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        linear_method: Optional linear/quantization backend for the
            projections.
        bias: Whether the qkv and output projections carry a bias.
        sliding_window: Optional sliding-window size forwarded to
            ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Per-rank KV heads; floors at 1 in the replicated case.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Widths of this rank's q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # This will be overwritten by model initialization if we are using it.
        # N.B. currently we only support per tensor scalar scaling factors
        # & only applicable to ROCm (AMD GPU).
        # The scaling factor convention we are assuming is
        # quantized_value * scaling_factor ~= true_value
        # which is consistent with the practice of setting
        # scaling_factor = tensor_amax / FPtype_max
        self.kv_scale = 1.0

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention for one batch of tokens.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations; presumably
                ``(num_tokens, hidden_size)`` — TODO confirm against callers.
            kv_cache: This layer's KV-cache tensor, passed through to
                ``Attention``.
            attn_metadata: Batch metadata passed through to ``Attention``.

        Returns:
            The ``o_proj`` output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split this rank's fused projection into its q, k and v slices.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
                                self.kv_scale)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer block.

    The block applies ``input_layernorm``, then self-attention, then
    ``post_attention_layernorm``, then the MLP. The residual stream is
    returned separately, not added to the output.

    Optional architecture settings are read from ``config`` with ``getattr``,
    so configs that lack them fall back to the defaults used here.

    Args:
        config: HuggingFace ``LlamaConfig`` (or a compatible config object).
        linear_method: Optional linear/quantization backend forwarded to the
            attention and MLP projections.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        sliding_window = getattr(config, "sliding_window", None)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Configs without num_key_value_heads use plain multi-head
            # attention (one KV head per query head).
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
            bias=attention_bias,
            sliding_window=sliding_window,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block on one batch of tokens.

        Args:
            positions: Token positions forwarded to self-attention.
            hidden_states: Output of the previous layer (or the embeddings
                for the first layer).
            kv_cache: This layer's KV-cache tensor.
            attn_metadata: Batch metadata forwarded to self-attention.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            stream. The caller is expected to pass both to the next layer
            (or to a final two-argument norm).
        """
        # Self Attention
        if residual is None:
            # First layer: the un-normalized input seeds the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # The two-argument RMSNorm returns a (normalized, residual) pair;
            # presumably it folds hidden_states into residual before
            # normalizing — confirm against RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """LLaMA decoder stack without the LM head.

    The stack consists of a vocab-parallel token embedding,
    ``config.num_hidden_layers`` decoder layers and a final RMSNorm.

    Args:
        config: HuggingFace ``LlamaConfig`` (or a compatible config object).
        linear_method: Optional linear/quantization backend forwarded to
            every decoder layer.
        lora_config: If given, the embedding table is enlarged by
            ``lora_extra_vocab_size * (max_loras or 1)`` rows beyond
            ``config.vocab_size`` for LoRA-added tokens.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the embedding, all decoder layers and the final norm.

        Args:
            input_ids: Token ids; ignored when ``inputs_embeds`` is given.
            positions: Token positions forwarded to every layer.
            kv_caches: One KV-cache tensor per decoder layer, indexed by
                layer position.
            attn_metadata: Batch metadata forwarded to every layer.
            inputs_embeds: Optional precomputed embeddings used instead of
                looking up ``input_ids``.

        Returns:
            The final-normalized hidden states.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        # The residual stream is threaded through the layers separately from
        # hidden_states; the first layer receives None.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal language model for vLLM inference.

    The model wraps ``LlamaModel`` with a parallel LM head, a logits
    processor and a sampler. It also carries the metadata used for LoRA and
    for loading HuggingFace checkpoints with fused (packed) projections.

    Args:
        config: HuggingFace ``LlamaConfig`` (or a compatible config object).
        linear_method: Optional linear/quantization backend forwarded to the
            decoder stack.
        lora_config: Optional LoRA configuration; enlarges the vocabulary
            and the LM-head padding.
    """

    # Fused parameter -> checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            # NOTE(review): the LM head adds lora_extra_vocab_size rows only,
            # whereas LlamaModel's embedding adds
            # lora_extra_vocab_size * (max_loras or 1); presumably this is
            # intentional — confirm against the LoRA layer wrappers.
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (DEFAULT_VOCAB_PADDING_SIZE
                        if not lora_config else
                        lora_config.lora_vocab_padding_size)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states from the decoder stack.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits using the LM-head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs from a HuggingFace-style checkpoint.

        Separate q/k/v and gate/up tensors are routed into the matching
        shard of the fused ``qkv_proj`` / ``gate_up_proj`` parameters. All
        other tensors are loaded one-to-one by name.

        The following checkpoint tensors are skipped:
        - rotary-embedding caches;
        - GPTQ bias tensors that have no matching parameter.

        Raises:
            KeyError: If a (non-skipped) checkpoint tensor has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` (rather
                # than `continue`) stops the shard search here: the shard
                # names overlap ("v_proj" occurs inside "qkv_proj", "up_proj"
                # inside "gate_up_proj"), so scanning on would re-match and
                # mangle the name. It also bypasses the one-to-one fallback.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a shard of a fused parameter: load it one-to-one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV-cache scaling factors into ``self_attn.kv_scale``.

        Args:
            quantization_param_path: Path passed to ``kv_cache_scales_loader``.
                The loader also receives this process's TP rank and size, the
                layer count and the config's ``model_type``.

        Raises:
            RuntimeError: If a layer's attention module has no ``kv_scale``
                attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                # NOTE(review): the factor is doubled only on ROCm; presumably
                # this compensates for the ROCm FP8 format's range differing
                # from the format the scales were computed for — confirm.
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")