llama.py 12.8 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
23
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.attention import PagedAttention
33
34
35
36
37
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.sampler import Sampler
40
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.model_executor.parallel_utils.parallel_state import (
43
    get_tensor_model_parallel_world_size)
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
46
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
47
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
48
49
50
51
52

# Key/value cache pair for one decoder layer. LlamaModel.forward passes
# kv_caches[i] to layer i, and LlamaAttention.forward unpacks the pair as
# (k_cache, v_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Feed-forward sub-block of a LLaMA decoder layer.

    The gate and up projections are held in one merged column-parallel
    layer with two output partitions of ``intermediate_size`` each. Its
    output goes through ``SiluAndMul`` and then a row-parallel down
    projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the two merged projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Optional linear method forwarded to both
                parallel linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # One partition for the gate projection, one for the up projection.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply merged projection, activation, and down projection to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention sub-block of a LLaMA decoder layer.

    Query heads are divided evenly across the tensor-parallel world size.
    KV heads are either divided the same way or, when there are fewer KV
    heads than ranks, kept at one per rank.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the QKV/output projections, rotary embedding and attention.

        Args:
            hidden_size: Width of the block's input and output.
            num_heads: Total number of query heads (before partitioning).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed through to ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            linear_method: Optional linear method forwarded to both
                parallel linear layers.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Query heads must split evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= world_size:
            # At least one KV head per rank: the KV heads are partitioned
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        else:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across the tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        # Every rank holds at least one KV head, even when replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank sizes used by forward() to split the fused QKV output.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Positions handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: ``(k_cache, v_cache)`` pair for this layer.
            input_metadata: Forwarded unchanged to the attention layer.
            cache_event: Forwarded unchanged to the attention layer.

        Returns:
            The output of the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Split the last dimension into this rank's Q, K and V slices.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)

        key_cache, value_cache = kv_cache
        attended = self.attn(query, key, value, key_cache, value_cache,
                             input_metadata, cache_event)
        projected, _ = self.o_proj(attended)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: attention and MLP, each preceded by RMSNorm.

    The residual stream is threaded through the layer explicitly:
    ``forward`` takes the incoming residual and returns the updated one
    together with the MLP output.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the layer from ``config``.

        ``rope_theta``, ``rope_scaling`` and ``max_position_embeddings`` are
        optional on the config and fall back to 10000, ``None`` and 8192.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention block.
            hidden_states: Input activations for this layer.
            kv_cache: This layer's ``(k_cache, v_cache)`` pair.
            input_metadata: Forwarded to the attention block.
            cache_event: Forwarded to the attention block.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet (the first layer).
        """
        # Attention sub-block.
        if residual is not None:
            # Called with a residual, the norm returns the pair
            # (normalized hidden states, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # MLP sub-block.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
224
225
226
227


class LlamaModel(nn.Module):
    """LLaMA backbone: token embedding, decoder layers and a final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` layers and norm.

        Args:
            config: Model configuration.
            linear_method: Optional linear method forwarded to every
                decoder layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids passed to the embedding.
            positions: Forwarded to every layer.
            kv_caches: One KV cache pair per layer, indexed by layer.
            input_metadata: Forwarded to every layer.
            cache_events: Optional per-layer events, indexed by layer;
                ``None`` means no layer receives an event.

        Returns:
            The first item returned by the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and creates it itself.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            if cache_events is None:
                layer_event = None
            else:
                layer_event = cache_events[layer_idx]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
                layer_event,
                residual,
            )
        # The final norm also takes the residual; its second result is unused.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA backbone plus LM head and sampler, with a checkpoint loader."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the backbone, the LM head and the sampler from ``config``."""
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Return the backbone's hidden states; the LM head is not applied."""
        return self.model(input_ids, positions, kv_caches, input_metadata,
                          cache_events)

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Call the sampler with the LM head weight and ``hidden_states``."""
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint tensors into this module's parameters.

        Tensors are read from ``hf_model_weights_iterator``. Tensors named
        after a separate q/k/v or gate/up projection are loaded as shards
        of the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Every other
        tensor goes to the parameter with the same name.

        Args:
            model_name_or_path: Forwarded to the weights iterator.
            cache_dir: Forwarded to the weights iterator.
            load_format: Forwarded to the weights iterator.
            revision: Forwarded to the weights iterator.

        Raises:
            KeyError: If a checkpoint tensor maps to no parameter name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Checkpoint entries that are never loaded. Models trained using
        # ColossalAI may include the cos/sin caches in the checkpoint.
        skipped_fragments = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in skipped_fragments):
                continue

            # First mapping entry whose shard name occurs in the tensor name.
            stacked = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if stacked is None:
                # Not part of a fused parameter: load it under its own name,
                # using the default loader if the parameter has none.
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
                continue

            fused_name, shard_name, shard_id = stacked
            param = params_dict[name.replace(shard_name, fused_name)]
            param.weight_loader(param, loaded_weight, shard_id)