llama.py 17 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import LoRAConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.activation import SiluAndMul
33
34
35
36
37
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.sampler import Sampler
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm.model_executor.parallel_utils.parallel_state import (
44
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
from vllm.model_executor.weight_utils import (default_weight_loader,
47
48
                                              hf_model_weights_iterator,
                                              kv_cache_scales_loader)
49
from vllm.sequence import SamplerOutput
50
from vllm.utils import is_hip
Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
53


class LlamaMLP(nn.Module):
    """Gated feed-forward block used inside every LLaMA decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer. Its output goes through ``SiluAndMul`` and is then
    projected back to ``hidden_size`` by a row-parallel linear layer.

    Args:
        hidden_size: Width of the model's residual stream.
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        linear_method: Optional quantization / linear backend passed to
            every linear layer.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # One output shard for the gate projection, one for the up projection.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply down_proj(SiluAndMul(gate_up_proj(x)))."""
        # The linear layers return a (output, bias) pair; the bias is unused.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention for LLaMA, sharded across tensor-parallel ranks.

    Query heads are divided evenly across ranks. Key/value heads are handled
    in one of two ways:

    * Partitioned across ranks when there are at least as many KV heads as
      ranks.
    * Replicated otherwise, so that every rank holds one KV head.

    Rotary position embeddings are applied to ``q`` and ``k`` before the
    attention kernel runs.

    Args:
        hidden_size: Width of the residual stream.
        num_heads: Total number of query heads (across all ranks).
        num_kv_heads: Total number of key/value heads (across all ranks).
        rope_theta: Base passed to the rotary embedding.
        rope_scaling: Optional rotary scaling config forwarded to
            ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        linear_method: Optional quantization / linear backend.
        bias: Whether the QKV and output projections carry a bias.
        sliding_window: Optional sliding-window size forwarded to
            ``Attention``.

    Raises:
        ValueError: If the head counts cannot be split or replicated evenly
            across the tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Explicit checks instead of `assert`, so they survive `python -O`.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Total number of attention heads ({self.total_num_heads}) "
                f"must be divisible by the tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            if self.total_num_kv_heads % tp_size != 0:
                raise ValueError(
                    f"Total number of KV heads ({self.total_num_kv_heads}) "
                    f"must be divisible by the tensor parallel size "
                    f"({tp_size}).")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            if tp_size % self.total_num_kv_heads != 0:
                raise ValueError(
                    f"Tensor parallel size ({tp_size}) must be divisible by "
                    f"the total number of KV heads "
                    f"({self.total_num_kv_heads}) when KV heads are "
                    f"replicated.")
        # When replicating, every rank still holds one full KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # This will be overwritten by model initialization if we are using it.
        # N.B. currently we only support per tensor scalar scaling factors
        # & only applicable to ROCm (AMD GPU).
        # The scaling factor convention we are assuming is
        # quantized_value * scaling_factor ~= true_value
        # which is consistent with the practice of setting
        # scaling_factor = tensor_amax / FPtype_max
        self.kv_scale = 1.0

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project back out.

        The fused QKV output is split along its last dimension into q, k and
        v. Rotary embeddings are applied to q and k using ``positions``. The
        attention kernel reads and writes ``kv_cache`` according to
        ``attn_metadata``.
        """
        # Linear layers return (output, bias); the bias is discarded.
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
                                self.kv_scale)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """A single LLaMA transformer layer: attention then MLP, each pre-normed.

    The residual stream is carried between layers explicitly as a separate
    tensor. On the first layer ``residual`` is ``None``. After that, the
    two-argument form of ``RMSNorm`` is used, which returns both the
    normalized tensor and the updated residual.

    Args:
        config: HuggingFace ``LlamaConfig``. Optional fields (rope settings,
            KV-head count, bias, sliding window) fall back to defaults when
            missing from the config object.
        linear_method: Optional quantization / linear backend.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Without num_key_value_heads the model is plain multi-head attention.
        kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            linear_method=linear_method,
            bias=getattr(config, "bias", False),
            sliding_window=getattr(config, "sliding_window", None),
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mlp_output, residual)`` for the next layer to consume."""
        # --- self-attention sub-block ---
        if residual is not None:
            # RMSNorm with a residual returns (normalized, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # --- feed-forward sub-block ---
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
237
238
239
240


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers and final norm.

    Args:
        config: HuggingFace ``LlamaConfig``.
        linear_method: Optional quantization / linear backend passed to
            every decoder layer.
        lora_config: Optional LoRA config. When given, the embedding table is
            enlarged by ``lora_extra_vocab_size * (max_loras or 1)`` rows,
            presumably one block of extra-token rows per LoRA slot.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the full decoder stack and return the final-norm hidden states.

        Args:
            input_ids: Token ids; ignored when ``inputs_embeds`` is given.
            positions: Positions forwarded to every layer's rotary embedding.
            kv_caches: One KV cache per decoder layer, indexed by layer.
            attn_metadata: Attention metadata shared by all layers.
            inputs_embeds: Precomputed input embeddings; take precedence over
                ``input_ids``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is
                provided.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            raise ValueError(
                "Either input_ids or inputs_embeds must be provided.")
        # The first layer seeds the residual stream from its own input.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        # Final norm also folds in the last residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal language model: decoder stack, LM head, logits and sampler.

    Also declares the metadata the LoRA machinery reads:

    * which fused modules are packed from which checkpoint projections;
    * which modules accept LoRA adapters;
    * which modules are embeddings.
    """

    # Fused module name -> checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        extra_vocab = lora_config.lora_extra_vocab_size if lora_config else 0
        self.unpadded_vocab_size = config.vocab_size + extra_vocab
        # LoRA kernels need a larger vocab padding than the default.
        if lora_config:
            vocab_padding = lora_config.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
        )

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            getattr(config, "logit_scale", 1.0),
        )
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder stack's final hidden states (no LM head)."""
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits with the LM-head weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` via the model's sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy HuggingFace checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        matching shard of the fused ``qkv_proj`` / ``gate_up_proj``
        parameters. Rotary-embedding caches and biases that have no matching
        parameter (e.g. from GPTQ checkpoints) are skipped.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # inv_freq is recomputed locally; cos/sin caches show up in models
        # trained with ColossalAI.
        ignored_fragments = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        def is_unexpected_bias(param_key: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return param_key.endswith(".bias") and param_key not in params_dict

        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            if any(fragment in name for fragment in ignored_fragments):
                continue
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # NOTE: `name` is rewritten in place, and a skipped bias
                # falls through to later mapping entries with the rewritten
                # name.
                name = name.replace(shard_name, fused_name)
                if is_unexpected_bias(name):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly into the same-named param.
                if is_unexpected_bias(name):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Set each layer's attention ``kv_scale`` from a quantization file.

        Raises:
            RuntimeError: If a layer's attention module has no ``kv_scale``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        per_layer_scales = kv_cache_scales_loader(
            quantization_param_path, tp_rank, tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type)
        for layer_idx, scaling_factor in per_layer_scales:
            attn_module = self.model.layers[layer_idx].self_attn

            # Assumed convention: quantized_value * scaling_factor ~=
            # true_value, i.e. scaling_factor = tensor_amax / FPtype_max.
            # NOTE(review): the factor is doubled on ROCm; presumably this
            # compensates for the FP8 variant used on AMD GPUs — confirm.
            if is_hip():
                scaling_factor *= 2
            if not hasattr(attn_module, "kv_scale"):
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
            attn_module.kv_scale = scaling_factor