llama.py 12.3 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
23
24
25
26
27
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
28
from typing import Any, Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
29
30
31
32
33

import torch
from torch import nn
from transformers import LlamaConfig

Woosuk Kwon's avatar
Woosuk Kwon committed
34
35
36
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
37
38
39
40
41
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.model_executor.layers.sampler import Sampler
43
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.model_executor.parallel_utils.parallel_state import (
46
47
48
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
49
from vllm.sequence import SamplerOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52
53
54

# Per-layer KV cache: a (key cache, value cache) pair of tensors, unpacked in
# that order by LlamaAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Feed-forward block of a LLaMA decoder layer.

    The gate and up projections are held in a single merged column-parallel
    linear layer with two outputs of ``intermediate_size`` each; the result
    is passed through ``SiluAndMul`` and projected back to ``hidden_size``
    by a row-parallel linear layer.

    Args:
        hidden_size: Input and output feature size of the block.
        intermediate_size: Output size of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        linear_method: Optional linear method forwarded to both linear
            layers.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        # One merged layer produces both the gate and the up projection.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, activation, then the down projection.

        The second element returned by each linear layer is ignored.
        """
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer under tensor parallelism.

    Query heads are divided evenly across the tensor-parallel ranks. KV heads
    are divided across ranks when there are at least as many KV heads as
    ranks, and replicated otherwise, so every rank holds at least one.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before partitioning).
        num_kv_heads: Total number of key/value heads (before partitioning).
        rope_theta: Passed to the attention module as ``base``.
        rope_scaling: Optional scaling config forwarded to the attention
            module unchanged.
        max_position_embeddings: Passed to the attention module as
            ``max_position``.
        linear_method: Optional linear method forwarded to the QKV and
            output projections.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly across the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: each KV head is replicated across
            # several tensor-parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        # In the replicated case the integer division yields 0, so clamp to 1.
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank sizes of the q and k/v slices of the fused QKV output.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # NOTE(review): per its name this module presumably applies rotary
        # position embeddings and paged KV-cache attention - confirm in its
        # definition.
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to q/k/v, run attention against the cache, project out.

        Args:
            positions: Position tensor handed to the attention module.
            hidden_states: Input to the fused QKV projection.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Forwarded to the attention module.
            cache_event: Optional CUDA event forwarded to the attention
                module.

        Returns:
            The output of the ``o_proj`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [q | k | v] along the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(split_sizes, dim=-1)
        key_cache, value_cache = kv_cache
        context = self.attn(positions, query, key, value, key_cache,
                            value_cache, input_metadata, cache_event)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: self-attention followed by the MLP.

    Each sub-block is preceded by an RMSNorm. The layer takes and returns a
    ``(hidden_states, residual)`` pair so the residual stream can be threaded
    through the norm calls of consecutive layers.

    Args:
        config: HuggingFace LLaMA config. ``rope_theta``, ``rope_scaling``,
            ``max_position_embeddings`` and ``num_key_value_heads`` are
            optional and fall back to defaults when the config lacks them.
        linear_method: Optional linear method forwarded to the attention and
            MLP sub-blocks.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Configs without `num_key_value_heads` describe plain multi-head
        # attention, i.e. one KV head per query head. Read it defensively,
        # like the other optional fields above, instead of raising
        # AttributeError.
        num_kv_heads = getattr(config, "num_key_value_heads",
                               config.num_attention_heads)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer on ``hidden_states``.

        Args:
            positions: Position tensor forwarded to self-attention.
            hidden_states: Input to this layer.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Forwarded to self-attention.
            cache_event: Optional CUDA event forwarded to self-attention.
            residual: Residual carried over from the previous layer, or
                ``None`` for the first layer.

        Returns:
            ``(hidden_states, residual)``: the MLP output and the residual
            returned by the post-attention norm.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # NOTE(review): with two arguments RMSNorm returns an updated
            # (hidden_states, residual) pair - presumably a fused
            # add-and-normalize; confirm in RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
222
223
224
225


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final norm.

    Args:
        config: HuggingFace LLaMA config supplying the vocabulary size,
            hidden size, number of layers and RMSNorm epsilon.
        linear_method: Optional linear method forwarded to every decoder
            layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids fed to the embedding layer.
            positions: Position tensor forwarded to every layer.
            kv_caches: One ``(key_cache, value_cache)`` pair per layer,
                indexed by layer position.
            input_metadata: Forwarded to every layer.
            cache_events: Optional per-layer CUDA events, indexed by layer
                position; ``None`` passes ``None`` to every layer.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual; later layers receive the one
        # returned by the previous layer.
        residual = None
        for idx, layer in enumerate(self.layers):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[idx]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                cache_event,
                residual,
            )
        # The final norm also consumes the last residual; only the
        # normalized hidden states are returned.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack with an LM head and a sampler.

    Args:
        config: HuggingFace LLaMA config.
        linear_method: Optional linear method forwarded to the decoder stack.
    """

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder stack and sample from its output.

        All arguments are forwarded unchanged to the decoder stack; the
        sampler then receives the LM head weight, the resulting hidden
        states and ``input_metadata``.

        Returns:
            Whatever the sampler returns for this batch.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this module's parameters.

        Checkpoint tensors whose names contain ``q_proj``/``k_proj``/
        ``v_proj`` or ``gate_proj``/``up_proj`` are loaded as shards of the
        merged ``qkv_proj`` / ``gate_up_proj`` parameters. Tensors whose
        names contain ``rotary_emb.inv_freq`` are skipped. Every other tensor
        is loaded into the parameter of the same name.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a (possibly renamed) checkpoint tensor name has no
                matching parameter in this module.
        """
        # (merged parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # The first mapping entry whose shard name occurs in the
            # checkpoint name decides the target; None means a plain weight.
            match = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if match is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
            else:
                param_name, weight_name, shard_id = match
                param = params_dict[name.replace(weight_name, param_name)]
                loader = param.weight_loader
                loader(param, loaded_weight, shard_id)