# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors. It is
# unpacked in LlamaAttention.forward and handed to the Attention layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """LLaMA feed-forward block.

    A fused gate/up column-parallel projection is followed by ``SiluAndMul``
    and a row-parallel down projection back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Output width of each of the gate and up
            projections (before tensor-parallel sharding).
        hidden_act: Activation name from the model config. Only ``"silu"``
            is supported.
        linear_method: Optional linear method forwarded to both parallel
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Validate before constructing the (potentially large) parallel
        # linear layers so an unsupported config fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # gate_proj and up_proj share one matmul. The two output slices are
        # loaded as shards 0 and 1 (see load_weights' stacked mapping).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, gated activation, then down projection.

        The parallel linear layers return an ``(output, bias)`` pair; only
        the output is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        # NOTE(review): presumably computes silu(gate) * up over the two
        # halves of ``gate_up`` - see SiluAndMul.
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Rotary-embedding self-attention sharded across tensor-parallel ranks.

    Supports fewer key/value heads than query heads (``num_kv_heads <
    num_heads``). Query heads are split evenly across TP ranks. KV heads are
    split when there are at least as many of them as ranks, and replicated
    otherwise.

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads across all TP ranks.
        num_kv_heads: Total number of key/value heads across all TP ranks.
        rope_theta: Rotary embedding base, forwarded to ``get_rope``.
        rope_scaling: Optional rope-scaling config, forwarded to ``get_rope``.
        max_position_embeddings: Maximum position, forwarded to ``get_rope``.
        linear_method: Optional linear method for the projection layers.
        bias: Whether the qkv and output projections carry a bias.
        sliding_window: Optional sliding-window size for ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        # Query heads held by this rank.
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Each rank holds at least one KV head (the replicated case).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Layer input.
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to ``Attention``.

        Returns:
            The output projection of the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection (last dim) into this rank's q, k, v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer block.

    The block runs pre-norm self-attention, then a pre-norm MLP. The
    residual stream is threaded through the RMSNorm calls rather than added
    explicitly here.

    Args:
        config: HuggingFace Llama config. Optional attributes are read with
            ``getattr`` defaults.
        linear_method: Optional linear method for all projection layers.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        sliding_window = getattr(config, "sliding_window", None)
        # transformers' LlamaConfig names the q/k/v/o projection bias flag
        # ``attention_bias``, while some configs use ``bias``. Honor either.
        # Reading only ``bias`` would build bias-free projections for
        # ``attention_bias=True`` checkpoints. LlamaForCausalLM.load_weights
        # would then silently skip those checkpoints' ``.bias`` tensors.
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Configs without num_key_value_heads use one KV head per
            # query head.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
            bias=attention_bias,
            sliding_window=sliding_window,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-layers.

        Args:
            positions: Token positions for rotary embeddings.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to attention.
            residual: Residual carried from the previous layer. It is
                ``None`` for the first layer, in which case ``hidden_states``
                itself becomes the residual.

        Returns:
            ``(mlp_output, residual)``. The MLP output has not yet been added
            to the residual; the next layer's input norm (or the model's
            final norm) consumes both.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # NOTE(review): with a residual argument RMSNorm returns
            # (normalized, updated_residual) - presumably a fused
            # add-then-normalize; see RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """LLaMA decoder stack.

    The stack is a vocab-parallel token embedding, ``num_hidden_layers``
    decoder layers, and a final RMSNorm.

    When ``lora_config`` is given, the embedding table gains
    ``lora_extra_vocab_size * (max_loras or 1)`` rows beyond the base
    vocabulary. ``org_vocab_size`` keeps the base size.

    Args:
        config: HuggingFace Llama config.
        linear_method: Optional linear method for all projection layers.
        lora_config: Optional LoRA config that enlarges the embedding table.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows for LoRA-added tokens, sized for max_loras
        # adapters (at least one).
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions for rotary embeddings.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer position.
            input_metadata: Batch metadata forwarded to every layer.

        Returns:
            Final normalized hidden states (not logits).
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        # Folds the last layer's MLP output into the residual and normalizes.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack plus a parallel LM head, logits processor and
    sampler, with a HuggingFace-checkpoint weight loader.

    Args:
        config: HuggingFace Llama config.
        linear_method: Optional linear method for all projection layers.
        lora_config: Optional LoRA config. It enlarges the vocabulary and
            the LM-head padding.
    """

    # Fused module name -> checkpoint module names packed into it. This
    # mirrors ``stacked_params_mapping`` in ``load_weights``.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        # NOTE(review): LlamaModel reserves
        # lora_extra_vocab_size * (max_loras or 1) extra embedding rows,
        # while the LM head adds lora_extra_vocab_size only once.
        # Presumably intentional (extra-token ids shared across adapters)
        # - confirm.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                        lora_config.lora_vocab_padding_size)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return final hidden states.

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM-head
        weight, via the logits processor."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace-format checkpoint weights into this model.

        Separate q/k/v and gate/up checkpoint tensors are loaded as shards
        of the fused ``qkv_proj`` / ``gate_up_proj`` parameters. All other
        tensors are loaded by name. Rotary-embedding buffers and ``.bias``
        tensors that have no matching parameter (e.g. extra GPTQ biases)
        are skipped.

        Args:
            model_name_or_path: Checkpoint location, forwarded to
                ``hf_model_weights_iterator``.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format selector.
            revision: Optional model revision.

        Raises:
            KeyError: If a non-skipped checkpoint tensor has no matching
                parameter in this model.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use ``break`` (which
                # also skips the ``else`` below) rather than ``continue``:
                # continuing would re-match the already-renamed name against
                # later entries ("v_proj" is a substring of "qkv_proj",
                # "up_proj" of "gate_up_proj") and mangle it further.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor by its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)