# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
33
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
41
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm.model_executor.layers.sampler import Sampler
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
47
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import SamplerOutput
50
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
53


class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into one column-parallel matmul
    (``gate_up_proj``), combined by ``SiluAndMul``, and projected back to
    ``hidden_size`` by the row-parallel ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        """Create the fused gate/up projection, down projection and activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the two fused projections
                (the gate/up layer is built with ``[intermediate_size] * 2``).
            hidden_act: Activation name from the model config. Only
                ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            bias: Whether the linear layers carry bias terms.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share one matmul; its output holds both
        # intermediate_size-wide halves, which SiluAndMul consumes together.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, then down projection."""
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Rotary-embedded self-attention sharded over the tensor-parallel group.

    Supports grouped-query attention: the number of KV heads may be smaller
    than the number of query heads.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Work out this rank's share of heads and build the sub-layers.

        Args:
            hidden_size: Model hidden width; also the input/output width.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rope-scaling dict forwarded to
                ``get_rope``.
            max_position_embeddings: Maximum position passed to
                ``get_rope``.
            quant_config: Optional quantization config for the projections.
            bias: Whether the QKV / output projections carry biases.
            sliding_window: Optional sliding-window size for ``Attention``.
            cache_config: Optional KV-cache config for ``Attention``.
        """
        super().__init__()
        world = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads must divide evenly across the TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # KV heads are partitioned when there are at least as many as ranks.
        # Otherwise they are replicated, so every rank keeps at least one.
        self.total_num_kv_heads = num_kv_heads
        kv_partitioned = self.total_num_kv_heads >= world
        if kv_partitioned:
            assert self.total_num_kv_heads % world == 0
        else:
            assert world % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5  # 1 / sqrt(head_dim)
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Per-tensor scalar KV-cache scaling factor. Only per-tensor scalars
        # are supported, and per the original note only on ROCm (AMD GPU).
        # Convention: quantized_value * scaling_factor ~= true_value, i.e.
        # scaling_factor = tensor_amax / FPtype_max. The default 1.0 is
        # replaced by LlamaForCausalLM.load_kv_cache_scales when scales are
        # loaded.
        self.kv_scale = 1.0

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        all_heads_width = self.total_num_heads * head_dim
        self.o_proj = RowParallelLinear(
            all_heads_width,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window,
                              cache_config=cache_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        projected, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = projected.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata,
                            self.kv_scale)
        out, _ = self.o_proj(context)
        return out


class LlamaDecoderLayer(nn.Module):
    """A single LLaMA transformer block.

    Pre-norm layout: ``input_layernorm`` -> ``self_attn`` ->
    ``post_attention_layernorm`` -> ``mlp``. The residual stream is passed
    in and returned separately instead of being added inside this module.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MLP and RMSNorm sub-modules from ``config``.

        Optional config fields fall back to defaults when missing:
        ``rope_theta`` (10000), ``rope_scaling`` (None),
        ``max_position_embeddings`` (8192), ``sliding_window`` (None),
        ``num_key_value_heads`` (``num_attention_heads``) and ``mlp_bias``
        (False).
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            if getattr(config, "original_max_position_embeddings", None):
                # NOTE(review): this writes into the dict owned by ``config``,
                # so the extra key is visible to every holder of the config.
                rope_scaling["original_max_position_embeddings"] = (
                    config.original_max_position_embeddings)

        # The attention-bias flag is spelled ``attention_bias`` by
        # abacusai/Smaug-72B-v0.1 and ``bias`` by internlm/internlm-7b.
        use_attn_bias = (getattr(config, "attention_bias", False)
                         or getattr(config, "bias", False))

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=rope_scaling,
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            quant_config=quant_config,
            bias=use_attn_bias,
            sliding_window=getattr(config, "sliding_window", None),
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention then MLP and return ``(hidden_states, residual)``.

        ``residual`` is None for the first layer. Later layers pass the
        residual returned by the previous layer, and the input norm is then
        called with both tensors.
        """
        # Self Attention
        if residual is not None:
            # The norm receives the pending residual and hands back the
            # normalized tensor together with the new residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # First layer: the raw input itself becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build embeddings and ``config.num_hidden_layers`` decoder layers.

        When ``lora_config`` is given, the embedding table is enlarged by
        ``lora_extra_vocab_size * (max_loras or 1)`` rows.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        extra_vocab = 0
        if lora_config:
            # Room for every LoRA slot's extra tokens (max_loras falsy -> 1).
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final-normed hidden states.

        ``inputs_embeds``, when given, is used instead of embedding
        ``input_ids``. ``kv_caches[i]`` is handed to layer ``i``.
        """
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm also folds in the residual carried out of the
        # last layer; the second returned value is discarded.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack with an LM head, logits processor and sampler."""

    # Checkpoint shards that are fused into one vLLM parameter.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        When ``config.tie_word_embeddings`` is set, ``lm_head.weight`` is
        the same Parameter object as ``model.embed_tokens.weight``.
        """
        super().__init__()
        self.config = config
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            # NOTE(review): LlamaModel sizes its embedding with
            # lora_extra_vocab_size * (max_loras or 1), while the head adds
            # lora_extra_vocab_size only once. Confirm this is intentional.
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # LoRA needs bigger padding for kernel compatibility.
        padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                        lora_config.lora_vocab_padding_size)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (logits not computed)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states onto the vocabulary via the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy HF-format checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id.
        Cached rotary tensors, unmatched biases (GPTQ) and, with tied
        embeddings, ``lm_head.weight`` are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # With tied embeddings, lm_head.weight is the same Parameter
                # as model.embed_tokens.weight. named_parameters() reports a
                # shared Parameter only once, under the embedding's name, so
                # there is no "lm_head.weight" key here. The embedding
                # tensor already fills it.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Set each layer's ``self_attn.kv_scale`` from a scales file.

        Args:
            quantization_param_path: Path passed to
                ``kv_cache_scales_loader`` along with this rank's TP position.

        Raises:
            RuntimeError: If a layer's attention module has no ``kv_scale``
                attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                # NOTE(review): the doubling on ROCm presumably compensates
                # for the FP8 variant used there. Confirm against the kernel.
                scaling_factor *= 2
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")