# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
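    """Feed-forward block: a fused gate/up projection followed by a
    SiLU-and-multiply activation and a down projection."""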

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
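    """Multi-head attention with rotary position embeddings and a paged KV
    cache. The Q, K, and V projections are fused into one column-parallel
    linear layer, and the number of KV heads may be smaller than the number
    of query heads (grouped-query attention)."""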

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
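        # Rotary embeddings are applied to q and k inside the paged attention
        # op; rotary_dim equals the full head dimension here.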
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
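    """One transformer block: RMSNorm -> self-attention -> RMSNorm -> MLP,
    with the residual additions handled inside the RMSNorm calls."""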

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
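        # The residual is threaded through explicitly; RMSNorm adds it to its
        # input and returns both the normalized output and the new residual.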
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
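    """Token embedding, the stack of decoder layers, and the final RMSNorm."""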

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
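    """LlamaModel plus a parallel LM head and a sampler that turns hidden
    states into sampled next tokens."""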

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
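        # Run the decoder stack, then sample next tokens from the LM head
        # logits inside the sampler.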
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
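        # HF checkpoints store q/k/v and gate/up as separate tensors; map each
        # of those shards onto the fused parameters used by this model.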
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
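            # The rotary inv_freq buffer is not a parameter of this model, so
            # skip it.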
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)