# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
24
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.attention import PagedAttention
33
34
35
36
37
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
38
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.sampler import Sampler
40
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.model_executor.parallel_utils.parallel_state import (
43
44
45
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
46
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
47
48
49
50
51

# Cache entry handed to one attention layer: a pair of tensors that
# LlamaAttention.forward unpacks as (k_cache, v_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Feed-forward block of a LLaMA decoder layer.

    The input goes through one merged gate/up projection, the ``SiluAndMul``
    activation, and finally a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Validate the activation and build the projections.

        Args:
            hidden_size: Input width of the merged projection and output
                width of the down projection.
            intermediate_size: Width of each of the two merged projections
                and input width of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Passed through unchanged to both linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Fail fast: reject an unsupported activation before any projection
        # layer is constructed, so a bad config does not first pay for
        # building the linear layers.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # One merged layer holds both the gate and the up projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention layer: fused QKV projection, rotary embedding,
    ``PagedAttention`` and an output projection.

    Query heads are divided evenly across the tensor-parallel ranks. KV heads
    are divided when there are at least as many of them as ranks; otherwise
    each rank keeps one KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Derive the per-rank head layout and build the sub-layers.

        Args:
            hidden_size: Model width; also the input size of the QKV
                projection and output size of the output projection.
            num_heads: Total number of query heads over all ranks.
            num_kv_heads: Total number of key/value heads over all ranks.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` as ``rope_scaling``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            linear_method: Passed through to both linear layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly over the ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated, so the
            # rank count has to be a multiple of the KV head count.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least one KV head per rank: the KV heads are partitioned,
            # so they have to split evenly.
            assert self.total_num_kv_heads % world_size == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(self.total_num_kv_heads // world_size, 1)

        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            positions: Forwarded to the rotary embedding together with q and k.
            hidden_states: Input to the fused QKV projection.
            kv_cache: ``(k_cache, v_cache)`` pair handed to the attention op.
            input_metadata: Forwarded to the attention op.
            cache_event: Forwarded to the attention op; may be ``None``.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # Slice the fused projection along the last dim into q, k and v.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        key_cache, value_cache = kv_cache
        attended = self.attn(query, key, value, key_cache, value_cache,
                             input_metadata, cache_event)
        projected, _ = self.o_proj(attended)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: norm, self-attention, norm, MLP.

    The residual stream is carried alongside the hidden states and returned
    to the caller rather than being added back inside this layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the attention, MLP and both RMSNorm modules from ``config``.

        ``rope_theta``, ``rope_scaling`` and ``max_position_embeddings`` are
        optional on the config and fall back to ``10000``, ``None`` and
        ``8192`` respectively.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Forwarded to the attention layer.
            hidden_states: Layer input.
            kv_cache: Cache pair for this layer's attention.
            input_metadata: Forwarded to the attention layer.
            cache_event: Forwarded to the attention layer; may be ``None``.
            residual: Residual returned by the previous layer, or ``None``
                when there is none yet.
        """
        # Self-attention.
        if residual is not None:
            # With a residual, the norm is called with both tensors and
            # returns an updated (hidden_states, residual) pair.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the raw input becomes the residual and only
            # the hidden states are normalised.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Feed-forward.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
223
224
225
226


class LlamaModel(nn.Module):
    """Token embedding, the stack of decoder layers and the final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Create the embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm. ``linear_method`` is passed to each layer."""
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer and apply the final norm.

        Args:
            input_ids: Token ids given to the embedding.
            positions: Forwarded to every layer.
            kv_caches: One cache entry per layer, indexed by layer position.
            input_metadata: Forwarded to every layer.
            cache_events: One event per layer indexed the same way, or
                ``None`` to pass ``None`` to every layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer is called with no residual; each layer returns the
        # residual for the next one.
        residual = None
        for idx, layer in enumerate(self.layers):
            if cache_events is None:
                event = None
            else:
                event = cache_events[idx]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
                residual,
            )
        # The final norm is called with the residual too; only its first
        # return value is kept.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder with an LM head and a sampler on top.

    ``forward`` runs the decoder and hands the LM-head weight plus the final
    hidden states to the sampler. ``load_weights`` fills the parameters from
    a checkpoint, folding the separate q/k/v and gate/up tensors into the
    fused ``qkv_proj`` and ``gate_up_proj`` parameters.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the decoder, LM head and sampler from ``config``.

        Args:
            config: Model configuration; ``vocab_size`` and ``hidden_size``
                size the LM head, ``vocab_size`` sizes the sampler.
            linear_method: Stored on the instance and passed to the decoder.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder and return the sampler's output.

        All arguments are forwarded unchanged to the decoder;
        ``input_metadata`` is also passed to the sampler.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Load checkpoint tensors into this model's parameters.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor (after shard-name mapping) does
                not match any parameter and is not one of the skipped
                rotary-embedding tensors.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Rotary-embedding tensors that a checkpoint may carry. They are not
        # looked up in params_dict: before this list was extended, only
        # inv_freq was skipped and the cached cos/sin tables would fall
        # through to params_dict[name] below.
        # NOTE(review): assumes the module returned by get_rope exposes no
        # loadable parameters under these names - confirm.
        skipped_name_parts = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if any(part in name for part in skipped_name_parts):
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint shard -> fused parameter: rewrite the name and
                # tell the parameter's loader which shard this tensor is.
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No shard mapping matched: load the tensor under its own
                # name, using the parameter's loader when it defines one.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)