# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
Woosuk Kwon's avatar
Woosuk Kwon committed
30

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.config import CacheConfig, LoRAConfig
33
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
34
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.activation import SiluAndMul
36
from vllm.model_executor.layers.layernorm import RMSNorm
37
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
38
39
                                               QKVParallelLinear,
                                               RowParallelLinear)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
42
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
43
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
44
from vllm.model_executor.layers.sampler import Sampler
45
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
47
48
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
49
from vllm.model_executor.sampling_metadata import SamplingMetadata
50
from vllm.sequence import IntermediateTensors, SamplerOutput
51
from vllm.utils import is_hip, print_warning_once
Woosuk Kwon's avatar
Woosuk Kwon committed
52

53
from .interfaces import SupportsLoRA
54
from .utils import is_pp_missing_parameter, make_layers
55

Woosuk Kwon's avatar
Woosuk Kwon committed
56
57

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    Computes ``down_proj(act_fn(gate_up_proj(x)))``. The gate and up
    projections are fused into one ``MergedColumnParallelLinear`` with two
    outputs of ``intermediate_size`` each.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        """Create the fused gate/up projection, the down projection and the
        activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two fused outputs and the
                input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.
            bias: Forwarded unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection is allocated.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused projection, the activation and the down
        projection to ``x`` and return the result."""
        # Both linear layers return a 2-tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention kernel and output projection.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    either split across ranks or replicated, depending on how their count
    compares with the tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention layer.

        Args:
            hidden_size: Model hidden size; also the output size of
                ``o_proj``.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP
                sharding).
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            bias: Forwarded to ``qkv_proj`` and ``o_proj``.
            cache_config: Forwarded to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when they are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output in forward().
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the output of
        ``o_proj``.

        Args:
            positions: Passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed to the attention layer.
            attn_metadata: Passed to the attention layer.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Opt-in workaround: with FA_PAD=1 and a fused output exactly 12320
        # wide, drop the trailing 32 columns so the split below sees
        # q_size + 2 * kv_size columns.
        # NOTE(review): presumably this undoes padding added to the QKV
        # weight elsewhere for one specific model/TP configuration; the
        # widths are hard-coded -- confirm before relying on it for others.
        if os.environ.get('FA_PAD') == '1' and qkv.shape[-1] == 12320:
            qkv = qkv[..., :-32]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    The residual stream is carried alongside the hidden states and is
    threaded through the two RMSNorm calls.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MLP and norm sub-modules from ``config``.

        Args:
            config: Model configuration; optional fields are read with
                ``getattr`` and fall back to the defaults shown below.
            cache_config: Forwarded to ``LlamaAttention``.
            quant_config: Forwarded to ``LlamaAttention`` and ``LlamaMLP``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Copy the top-level original_max_position_embeddings into the
        # rope_scaling mapping when both are present. This writes into the
        # mapping object held by the config.
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to one KV head per query head when the config does
            # not declare num_key_value_heads.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to self-attention.
            hidden_states: Input activations of this layer.
            kv_cache: Forwarded to self-attention.
            attn_metadata: Forwarded to self-attention.
            residual: Residual carried from the previous layer, or ``None``
                if there is none yet.
        """
        # Self Attention
        if residual is None:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual, the norm returns both the normalized states
            # and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
245
246
247
248


class LlamaModel(nn.Module):
    """Token embedding, the decoder layers in ``[start_layer, end_layer)``
    and the final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the embedding, decoder layers and final norm.

        Args:
            config: Model configuration.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            lora_config: When given, enlarges the embedding vocabulary by
                ``lora_extra_vocab_size * (max_loras or 1)``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        # make_layers returns the layer index range used by forward()
        # together with the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda: LlamaDecoderLayer(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            kv_caches: One cache per local layer, indexed from 0.
            attn_metadata: Forwarded to every decoder layer.
            intermediate_tensors: Required on non-first pipeline ranks;
                supplies ``"hidden_states"`` and ``"residual"``.
            inputs_embeds: Used instead of embedding ``input_ids`` on the
                first pipeline rank.

        Returns:
            The normalized hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``"hidden_states"``
            and ``"residual"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                # kv_caches is indexed relative to the first local layer.
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


318
class LlamaForCausalLM(nn.Module, SupportsLoRA):
    """LLaMA decoder stack with LM head, logits processor and sampler,
    plus checkpoint and KV-cache-scale loading."""

    # Fused module name -> checkpoint projection names merged into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Forwarded to ``LlamaModel``.
            quant_config: Forwarded to ``LlamaModel`` and the LM head.
            lora_config: When given, extends the LM head vocabulary by
                ``lora_extra_vocab_size`` and selects the LoRA vocab padding
                size.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (lora_config.lora_vocab_padding_size
                        if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            # Share one weight tensor between the LM head and the embedding.
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``self.model`` and return its output unchanged."""
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors)
        return model_output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the logits processor's output for ``hidden_states`` using
        the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Return zero-filled ``"hidden_states"`` and ``"residual"`` tensors
        of shape ``(batch_size, hidden_size)``."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint shard -> fused parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)

                break
            else:
                # Reached only when no stacked mapping consumed the tensor
                # (the inner loop ended without ``break``).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Set ``attn._kv_scale`` on each local attention layer from the
        scaling factors yielded by ``kv_cache_scales_loader``.

        Args:
            quantization_param_path: Passed to ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a decoder layer's attention exposes no KV cache
                scaling attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.model.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # An nn.Identity entry has no ``self_attn``. Skip it so we
                # neither hit an unbound variable nor overwrite the scale of
                # the previously visited layer.
                # NOTE(review): presumably these are placeholders for layers
                # hosted on another pipeline rank -- confirm in make_layers.
                continue
            layer_self_attn = layer.self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # Accept either the legacy ``kv_scale`` marker on the attention
            # wrapper or the ``_kv_scale`` attribute that is actually written
            # below; LlamaAttention in this file does not set ``kv_scale``.
            attn_layer = getattr(layer_self_attn, "attn", None)
            if (hasattr(layer_self_attn, "kv_scale")
                    or hasattr(attn_layer, "_kv_scale")):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")