# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.attention import PagedAttention
33
34
35
36
37
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.sampler import Sampler
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.model_executor.parallel_utils.parallel_state import (
43
    get_tensor_model_parallel_world_size)
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
46
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
47
from vllm.sequence import SamplerOutput
48
from vllm.config import LoRAConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
49
50
51
52
53

# Per-layer KV cache handle: a (key cache, value cache) pair of tensors.
# LlamaAttention.forward unpacks it as ``k_cache, v_cache = kv_cache`` and
# hands both to PagedAttention; the tensors' layout is defined by
# PagedAttention / the cache allocator, not in this file.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer; its output goes through ``SiluAndMul`` and is projected
    back to ``hidden_size`` by a row-parallel ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Create the fused gate/up projection, down projection and activation.

        Args:
            hidden_size: Width of the layer's input and output.
            intermediate_size: Output width of each of the gate and up
                projections (the fused layer produces both).
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Forwarded unchanged to both parallel linear
                layers (presumably selects e.g. a quantized matmul).

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Two equal-width shards (gate, up) packed into one column-parallel
        # matmul.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection.

        Each parallel linear layer returns a pair; only its first element
        is used here.
        """
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.down_proj(activated)
        return out


class LlamaAttention(nn.Module):
    """Tensor-parallel self-attention with rotary position embeddings.

    The number of key/value heads may differ from the number of query heads
    (e.g. grouped-/multi-query attention). Query heads are split evenly
    across tensor-parallel ranks. KV heads are split when there are at least
    as many of them as ranks, and replicated otherwise. The attention kernel
    itself is ``PagedAttention`` operating on the per-layer KV cache.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Create the QKV/output projections, rotary embedding and kernel.

        Args:
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Rotary embedding base, passed to ``get_rope``.
            rope_scaling: Optional rotary scaling config, passed to
                ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            linear_method: Forwarded to the QKV and output projections.

        Raises:
            AssertionError: If the head counts cannot be evenly distributed
                over the tensor-parallel world size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always partitioned evenly across TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads: with fewer heads than ranks they are replicated (so the
        # rank count must be a multiple of the head count); otherwise they
        # are partitioned (so the head count must divide evenly). Every rank
        # keeps at least one KV head.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            assert tp_size % num_kv_heads == 0
        else:
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # Per-rank widths of the Q and K/V slices of the fused QKV output,
        # and the 1/sqrt(head_dim) attention scale.
        head_dim = hidden_size // num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and project the result.

        Args:
            positions: Token positions, consumed by the rotary embedding.
            hidden_states: Layer input fed to the fused QKV projection.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Batch metadata forwarded to ``PagedAttention``.

        Returns:
            The output projection of the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        # Rotary embedding is applied to queries and keys only.
        q, k = self.rotary_emb(positions, q, k)
        key_cache, value_cache = kv_cache
        attn_output = self.attn(q, k, v, key_cache, value_cache,
                                input_metadata)
        projected, _ = self.o_proj(attn_output)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder block: norm, self-attention, norm, MLP.

    The residual stream is threaded through explicitly: ``forward`` takes
    the running residual and returns ``(hidden_states, residual)``. The
    two-argument RMSNorm calls return a ``(normed, residual)`` pair
    (presumably folding the residual add into the norm).
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build attention, MLP and both RMSNorms from a ``LlamaConfig``.

        ``rope_theta``, ``rope_scaling`` and ``max_position_embeddings`` are
        read with ``getattr`` and default to 10000, None and 8192 when the
        config does not define them.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        theta = getattr(config, "rope_theta", 10000)
        scaling = getattr(config, "rope_scaling", None)
        max_positions = getattr(config, "max_position_embeddings", 8192)

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=theta,
            rope_scaling=scaling,
            max_position_embeddings=max_positions,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block on ``hidden_states``.

        Args:
            positions: Token positions forwarded to self-attention.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to self-attention.
            residual: Residual carried from the previous layer, or ``None``;
                when ``None``, ``hidden_states`` itself becomes the residual.

        Returns:
            ``(mlp_output, residual)`` for the next layer.
        """
        # Self Attention
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final RMSNorm.

    Returns final hidden states only; the LM head and sampling live in
    ``LlamaForCausalLM``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm.

        Args:
            config: HuggingFace ``LlamaConfig``.
            linear_method: Forwarded to every decoder layer.
            lora_config: When given, the embedding table is enlarged by
                ``lora_extra_vocab_size * (max_loras or 1)`` rows.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows reserved for LoRA adapters: one block of
        # ``lora_extra_vocab_size`` rows per LoRA slot (at least one slot).
        # Presumably this lets each adapter add its own tokens beyond the
        # base vocabulary.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions forwarded to every layer.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed in layer order.
            input_metadata: Batch metadata forwarded to every layer.

        Returns:
            Hidden states after the final RMSNorm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and seeds it from its input.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal LM: ``LlamaModel`` plus a parallel LM head and sampler.

    The class-level tables describe how HuggingFace checkpoint modules map
    onto vLLM's fused modules and which modules LoRA may touch. They are not
    read inside this file (presumably consumed by vLLM's LoRA / loading
    machinery).
    """

    # Fused vLLM module -> HF checkpoint modules packed into it (in order).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack, LM head and sampler.

        Args:
            config: HuggingFace ``LlamaConfig``.
            linear_method: Forwarded to the decoder stack.
            lora_config: Enables LoRA vocabulary extension and the larger
                LM-head padding required by LoRA kernels.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        # NOTE(review): the LM head adds ``lora_extra_vocab_size`` once,
        # whereas LlamaModel's embedding adds it ``max_loras`` times. This
        # looks intentional (upstream does the same) -- confirm before
        # "unifying" them.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; token selection happens in ``sample``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``hidden_states`` using the LM head weight."""
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace-format checkpoint weights into this model.

        Checkpoint tensors whose name contains a per-projection shard name
        (``q_proj``/``k_proj``/``v_proj``, ``gate_proj``/``up_proj``) are
        written into the matching fused parameter (``qkv_proj`` /
        ``gate_up_proj``) through that parameter's ``weight_loader`` with the
        shard id. Every other tensor is loaded by name through the
        parameter's ``weight_loader`` (or ``default_weight_loader``).

        Args:
            model_name_or_path: Checkpoint location, passed to
                ``hf_model_weights_iterator``.
            cache_dir: Passed to ``hf_model_weights_iterator``.
            load_format: Passed to ``hf_model_weights_iterator``.
            revision: Passed to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor that is not skipped has no
                corresponding parameter in this model.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Not loaded from the checkpoint (presumably recomputed by the
            # rotary embedding module).
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep ``name`` untouched so it can never be re-matched (or
                # reach the ``else`` branch) in rewritten form.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. ``break`` drops
                # this tensor outright; ``continue`` would only advance to
                # the next mapping entry.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load it under its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)