# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
33
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import SiluAndMul
35
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
41
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm.model_executor.layers.sampler import Sampler
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
47
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import SamplerOutput
50
from vllm.utils import is_hip, print_warning_once
Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
53


class LlamaMLP(nn.Module):
    """Feed-forward block of a LLaMA decoder layer.

    The input goes through one ``MergedColumnParallelLinear`` that has two
    outputs of ``intermediate_size`` each, then through ``SiluAndMul``, and
    finally through a ``RowParallelLinear`` back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        """Create the two projections and the activation.

        Args:
            hidden_size: Input size of the merged projection and output size
                of the down projection.
            intermediate_size: Size of each of the two merged outputs and
                input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.
            bias: Forwarded unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        # The activation name is validated only after both projections exist.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply merged projection, activation and down projection to ``x``.

        Each linear layer returns a pair; only the first element is used.
        """
        merged, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(merged))
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    Chains a fused QKV projection, a rotary embedding applied to the query
    and key, the ``Attention`` layer, and an output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Derive the per-rank head layout and build the sub-layers.

        Args:
            hidden_size: Model hidden size; also the output size of
                ``o_proj``.
            num_heads: Total number of query heads over all ranks.
            num_kv_heads: Total number of key/value heads over all ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            bias: Forwarded to ``qkv_proj`` and ``o_proj``.
            cache_config: Forwarded to ``Attention``.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size

        # Query heads must divide evenly by the tensor-parallel world size.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than tensor-parallel GPUs: the KV heads are
            # replicated, so the world size must be a multiple of them.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as tensor-parallel GPUs: the KV
            # heads are partitioned across them.
            assert self.total_num_kv_heads % tp_size == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        The fused QKV output is split along its last dimension into chunks
        of ``q_size``, ``kv_size`` and ``kv_size``.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: attention and MLP, each preceded by RMSNorm.

    The residual stream is passed between layers explicitly as the second
    element of the tuple returned by ``forward``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Read the layer hyper-parameters from ``config`` and build the layer.

        Args:
            config: Model configuration. Optional fields fall back to the
                defaults given in the ``getattr`` calls below.
            cache_config: Forwarded to ``LlamaAttention``.
            quant_config: Forwarded to ``LlamaAttention`` and ``LlamaMLP``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config,
                                   "original_max_position_embeddings", None)
        if rope_scaling is not None and original_max_pos:
            # Writes into the dict object obtained from ``config`` (no copy
            # is made).
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # Without ``num_key_value_heads`` the query head count is used.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Passed through to the attention block.
            hidden_states: Input activations of this layer.
            kv_cache: KV cache tensor for this layer.
            attn_metadata: Passed through to the attention block.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet.
        """
        # Self Attention
        if residual is not None:
            # Two-argument norm call: returns the normed states together
            # with the new residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normed input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
239
240
241
242


class LlamaModel(nn.Module):
    """Token embedding, stack of ``LlamaDecoderLayer`` and final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the embedding table, the decoder layers and the final norm.

        Args:
            config: Model configuration.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            lora_config: When truthy, extends the embedding vocabulary by
                ``lora_extra_vocab_size * (max_loras or 1)`` entries.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for _ in range(config.num_hidden_layers):
            decoder_layers.append(
                LlamaDecoderLayer(config=config,
                                  cache_config=cache_config,
                                  quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of ``embed_tokens`` for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers and return the normed hidden states.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is
                ``None``.
            positions: Passed to every decoder layer.
            kv_caches: One entry per decoder layer, indexed by layer number.
            attn_metadata: Passed to every decoder layer.
            inputs_embeds: If given, used as the initial hidden states
                instead of embedding ``input_ids``.
        """
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        # The first layer receives no residual.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )

        # The final norm also returns a second value, which is discarded.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder with an LM head, logits processor and sampler.

    Wraps ``LlamaModel`` and adds checkpoint loading (``load_weights``) and
    loading of KV cache scaling factors (``load_kv_cache_scales``).
    """

    # Fused parameter name -> names of the per-projection checkpoint weights
    # that are packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Forwarded to ``LlamaModel``.
            quant_config: Forwarded to ``LlamaModel``.
            lora_config: When truthy, enlarges the LM head vocabulary by
                ``lora_extra_vocab_size`` and selects the LoRA vocabulary
                padding size.
        """
        super().__init__()
        self.config = config
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            # Tied embeddings: the LM head reuses the input embedding weight.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the hidden states produced by the wrapped ``LlamaModel``.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If a name that is not skipped has no matching entry in
                ``named_parameters()``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rewrite the per-projection name to the fused parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this ``continue`` advances the inner mapping loop,
                # with ``name`` already rewritten; the weight is finally
                # skipped by the same check in the ``else`` branch below.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors into the attention layers.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's attention module has no ``_kv_scale``
                attribute to receive the scaling factor.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # The scale is stored on the inner ``attn`` module as
            # ``_kv_scale``, so the guard must test that same attribute.
            # Testing ``kv_scale`` on ``self_attn`` itself never matched,
            # because ``LlamaAttention`` does not define it.
            attn_layer = getattr(layer_self_attn, "attn", None)
            if hasattr(attn_layer, "_kv_scale"):
                attn_layer._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")