# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput


class LlamaMLP(nn.Module):
    """Feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` built with two output partitions of
    ``intermediate_size`` each; ``SiluAndMul`` then consumes that fused
    output (presumably ``silu(gate) * up`` -- confirm against the
    activation layer).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the fused gate/up projection, the down projection and the
        activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two fused gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Forwarded unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported configurations before any projection layer is
        # constructed, so a bad config fails fast instead of after the
        # linear layers have already been built.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the fused gate/up projection, the activation, then the down
        projection, and return the result."""
        # Both linear layers return a 2-tuple; only the first element is
        # used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention block with a fused QKV projection and rotary embedding.

    Per-rank head counts are derived from the tensor-parallel world size
    returned by ``get_tensor_model_parallel_world_size()``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
        bias: bool = False,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Create the projections, rotary embedding and attention backend.

        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            linear_method: Forwarded to both linear layers.
            bias: Whether the QKV and output projections carry a bias.
            sliding_window: Forwarded to ``Attention``.

        Raises:
            AssertionError: If the head counts are not compatible with the
                tensor-parallel world size (see the checks below).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank, even when heads are replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Widths of the q and k/v slices of this rank's fused QKV output;
        # used by ``forward`` to split it.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=bias,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply the rotary embedding to q and k, run
        attention, and apply the output projection.

        Args:
            positions: Passed to the rotary embedding together with q and k.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed through to the attention backend.
            attn_metadata: Passed through to the attention backend.

        Returns:
            The first element of the output projection's result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    The residual stream is threaded explicitly through ``forward`` as a
    separate tensor rather than being added back inside the layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the attention, MLP and the two RMSNorm layers from ``config``.

        Optional config attributes (``rope_theta``, ``rope_scaling``,
        ``max_position_embeddings``, ``sliding_window``,
        ``num_key_value_heads``, ``bias``) fall back to defaults via
        ``getattr`` when the config does not define them.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        sliding_window = getattr(config, "sliding_window", None)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Without ``num_key_value_heads`` the KV head count equals the
            # query head count.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
            bias=getattr(config, "bias", False),
            sliding_window=sliding_window,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Layer input.
            kv_cache: Forwarded to the attention block.
            attn_metadata: Forwarded to the attention block.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer.

        Returns:
            The MLP output and the residual produced by the post-attention
            norm; the caller is expected to pass both to the next layer.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns a pair (presumably a
            # fused add + norm -- confirm against RMSNorm).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """Token embedding, a stack of ``LlamaDecoderLayer`` and a final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the embedding table, decoder layers and final norm.

        Args:
            config: Model configuration.
            linear_method: Forwarded to every decoder layer.
            lora_config: When given, the embedding vocabulary is extended by
                ``lora_extra_vocab_size * (max_loras or 1)`` entries.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of ``embed_tokens`` for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run every decoder layer and the final norm.

        Args:
            input_ids: Token ids to embed; ignored when ``inputs_embeds`` is
                given.
            positions: Forwarded to every layer.
            kv_caches: One cache per layer, indexed by layer position.
            attn_metadata: Forwarded to every layer.
            inputs_embeds: Precomputed embeddings that take precedence over
                ``input_ids``.

        Returns:
            The first element of the final norm's result.

        Raises:
            ValueError: If both ``input_ids`` and ``inputs_embeds`` are
                ``None``.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            # Fail with a clear message instead of an opaque error from
            # inside the embedding layer.
            raise ValueError(
                "Either input_ids or inputs_embeds must be provided.")
        # The first layer receives ``None`` and seeds the residual itself.
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder with an LM head, logits processor and sampler."""

    # Fused parameter name -> the checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            linear_method: Forwarded to ``LlamaModel``.
            lora_config: When given, the LM head vocabulary grows by
                ``lora_extra_vocab_size`` and uses
                ``lora_vocab_padding_size`` as its padding size.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # We need bigger padding if using lora for kernel compatibility.
        padding_size = (lora_config.lora_vocab_padding_size
                        if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=padding_size,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the decoder and return its hidden states; logits are computed
        separately by ``compute_logits``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the logits processor's output for ``hidden_states`` using
        the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this module's parameters.

        Tensors are read from ``hf_model_weights_iterator``. Names that
        contain a shard name from ``stacked_params_mapping`` are loaded into
        the corresponding fused parameter with their shard id; all other
        names are loaded directly.

        Raises:
            KeyError: If a non-skipped checkpoint name has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Use a separate local so the checkpoint name is never
                # rewritten while the mapping is still being scanned
                # ("v_proj" is a substring of "qkv_proj", so a rewritten
                # name could match a later entry).
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No shard name matched: load the tensor as-is.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)