# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

30
from vllm.config import LoRAConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
31
32
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
33
from vllm.model_executor.layers.attention import PagedAttention
34
35
36
37
38
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.sampler import Sampler
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
Woosuk Kwon's avatar
Woosuk Kwon committed
43
from vllm.model_executor.parallel_utils.parallel_state import (
44
    get_tensor_model_parallel_world_size)
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
47
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
48
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
49
50
51
52
53

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors. Attention
# layers unpack it as ``k_cache, v_cache = kv_cache``.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    Computes ``down_proj(act_fn(gate_up_proj(x)))``. ``gate_up_proj`` is a
    single merged column-parallel projection producing two output shards of
    ``intermediate_size`` each (gate and up). ``act_fn`` is ``SiluAndMul``,
    which presumably computes ``silu(gate) * up`` over those two halves.

    Args:
        hidden_size: Input and output width of the block.
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name from the model config. Only ``"silu"``
            is supported.
        linear_method: Passed through unchanged to both parallel linear
            layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # Validate the config before building the (possibly large, sharded)
        # projection layers so an unsupported activation fails fast instead
        # of first materializing weights that are immediately discarded.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x``; output width is ``hidden_size``."""
        # The parallel linear layers return a pair; only the first element
        # (the projected output) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention with rotary embeddings and a paged KV cache.

    The attention is sharded across tensor-parallel (TP) ranks. Query heads
    are always divided evenly between ranks. KV heads are divided when there
    are at least as many of them as ranks; otherwise they are replicated, so
    each rank keeps a single KV head. This covers multi-head
    (``num_kv_heads == num_heads``) as well as grouped/multi-query configs.

    Args:
        hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        rope_theta: Rotary embedding base, forwarded to ``get_rope``.
        rope_scaling: Optional rope scaling dict, forwarded to ``get_rope``.
        max_position_embeddings: Maximum position, forwarded to ``get_rope``.
        linear_method: Passed through to the QKV and output projections.
        bias: Whether the QKV and output projections carry a bias.
        sliding_window: Optional window size forwarded to ``PagedAttention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads: always an even split across TP ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        # KV heads: replicated when fewer than TP ranks (each rank then holds
        # one head), otherwise partitioned evenly like the query heads.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            assert world_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
                                   head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run one attention step.

        Projects ``hidden_states`` to Q/K/V, rotates Q and K at
        ``positions``, attends through the paged ``kv_cache`` and projects
        the result back to ``hidden_size``.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata)
        output, _ = self.o_proj(context)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA block: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded explicitly through ``forward``. Each
    norm is called with the running residual and returns an updated
    ``(hidden_states, residual)`` pair. The layer returns that pair
    un-merged, for the next layer (or the final norm) to consume.

    Args:
        config: HF ``LlamaConfig``. Optional fields (``num_key_value_heads``,
            ``rope_theta``, ``rope_scaling``, ``max_position_embeddings``,
            ``bias``, ``sliding_window``) fall back to defaults when absent.
        linear_method: Passed through to the attention and MLP projections.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Without a KV-head count, fall back to one KV head per query
            # head (plain multi-head attention).
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            linear_method=linear_method,
            bias=getattr(config, "bias", False),
            sliding_window=getattr(config, "sliding_window", None),
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the attention and MLP sub-blocks.

        Args:
            residual: The running residual stream, or ``None`` when there is
                none yet; ``hidden_states`` itself then becomes the residual.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            stream, kept separate for the caller to combine.
        """
        # --- Attention sub-block (pre-norm) ---
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the incoming activations start the stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # --- MLP sub-block (pre-norm) ---
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final RMSNorm.

    Args:
        config: HF ``LlamaConfig``.
        linear_method: Passed through to every decoder layer.
        lora_config: When given, the embedding table is enlarged by
            ``lora_extra_vocab_size`` rows per LoRA slot (``max_loras`` slots,
            treated as 1 when falsy) for LoRA extra-vocabulary entries.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        if lora_config:
            lora_slots = lora_config.max_loras or 1
            lora_vocab = lora_config.lora_extra_vocab_size * lora_slots
        else:
            lora_vocab = 0
        self.org_vocab_size = config.vocab_size
        self.vocab_size = self.org_vocab_size + lora_vocab
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run all decoder layers, apply the final norm.

        ``kv_caches`` is indexed by layer position: layer ``i`` uses
        ``kv_caches[i]``.
        """
        hidden_states = self.embed_tokens(input_ids)
        # No residual exists before the first layer; each layer hands back
        # the (hidden_states, residual) pair for the next one.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
                residual,
            )
        # The final norm also folds in the outstanding residual; its second
        # output is not needed.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal LM: ``LlamaModel`` backbone, LM head and sampler.

    ``forward`` returns the backbone's final hidden states. ``sample`` passes
    those, together with ``lm_head.weight``, to the ``Sampler`` (presumably
    to compute logits and draw tokens).
    """

    # Fused parameter -> HF checkpoint projections packed into it. Mirrors
    # ``stacked_params_mapping`` in ``load_weights``.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)

        # NOTE(review): the LM head grows by a single block of
        # ``lora_extra_vocab_size`` rows, whereas the input embedding in
        # LlamaModel grows by one block per LoRA slot. This is kept exactly
        # as the original implementation had it; confirm it is intentional.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (lora_config.lora_vocab_padding_size
                        if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``hidden_states`` using the LM-head weight."""
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace checkpoint tensors into this model's parameters.

        Separate HF ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` tensors are written into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters via the parameter's ``weight_loader``
        with a shard id. Every other tensor is loaded by name with the
        parameter's ``weight_loader`` (or ``default_weight_loader``).
        Rotary-embedding buffers and bias tensors with no matching parameter
        (e.g. in GPTQ checkpoints) are skipped.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a non-skipped checkpoint tensor has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep ``name`` intact. Rewriting it in place would let the
                # remaining mappings match the rewritten string ("up_proj"
                # is a substring of "gate_up_proj").
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. ``break`` (rather
                # than ``continue``) skips this tensor outright instead of
                # re-testing it against later mappings and the ``else``
                # fallback below.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked projection: load by its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)