# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs

# Per-layer KV cache: a pair of tensors unpacked as (key_cache, value_cache)
# by LlamaAttention.forward before being handed to the paged attention op.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Tensor-parallel feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer of width ``2 * intermediate_size``. Its output goes through
    ``SiluAndMul`` (presumably SiLU on one half times the other half; see that
    layer for the exact contract). A row-parallel ``down_proj`` then maps the
    result back to ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        fused_width = 2 * intermediate_size
        self.gate_up_proj = ColumnParallelLinear(
            hidden_size,
            fused_width,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply fused gate/up projection, activation, then down projection.

        Args:
            x: Input activations (presumably ``(num_tokens, hidden_size)``;
                confirm against callers).

        Returns:
            The first element returned by ``down_proj``; its second return
            value is discarded.
        """
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


class LlamaAttention(nn.Module):
    """LLaMA self-attention with rotary position embeddings.

    Query heads and key/value heads are each divided evenly across the
    tensor-parallel ranks. ``num_kv_heads`` may be smaller than ``num_heads``
    (grouped-query attention); the per-rank KV head count is passed to
    ``PagedAttentionWithRoPE``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        # Validate with explicit exceptions rather than `assert`, which is
        # stripped under `python -O`. ValueError matches the error style
        # LlamaMLP uses for an unsupported activation.
        if num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({num_heads}) must be divisible "
                f"by the tensor parallel size ({tp_size}).")
        if num_kv_heads % tp_size != 0:
            raise ValueError(
                f"Number of key/value heads ({num_kv_heads}) must be "
                f"divisible by the tensor parallel size ({tp_size}).")
        # head_dim is derived as hidden_size // num_heads below. A remainder
        # would silently truncate it and make the fused projection shapes
        # disagree with the checkpoint.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by the "
                f"number of attention heads ({num_heads}).")
        self.total_num_heads = num_heads
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        # Fused Q/K/V projection; each rank produces its [q | k | v] slice.
        self.qkv_proj = ColumnParallelLinear(
            hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           base=self.rope_theta,
                                           rotary_dim=self.head_dim,
                                           num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Compute self-attention for the flattened batch of tokens.

        Args:
            positions: Token positions, forwarded to the rotary attention op.
            hidden_states: Layer input activations.
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to the attention op.
            cache_event: Optional event forwarded to the attention op.

        Returns:
            The output of the row-parallel ``o_proj``.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # qkv_proj lays out this rank's output as [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer block.

    The block runs two sub-layers, each pre-normalized with RMSNorm and
    wrapped in a residual add: self-attention, then the MLP.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        # `rope_theta` only exists on newer LlamaConfig versions (requires
        # transformers > 4.32.0); use 10000 when the attribute is absent.
        theta = getattr(config, "rope_theta", 10000)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=theta,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            positions: Token positions, forwarded to self-attention.
            hidden_states: Block input activations.
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to self-attention.
            cache_event: Optional event forwarded to self-attention.

        Returns:
            The block output, shaped like ``hidden_states``.
        """
        # Attention sub-layer: normalize, attend, add back the block input.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        after_attn = hidden_states + attn_out

        # MLP sub-layer: normalize, project, add back the attention result.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Round the embedding table up to a multiple of 64 rows. The extra
        # checkpoint rows are padded in at weight-loading time.
        padded_vocab = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            padded_vocab, config.hidden_size, perform_initialization=False)
        decoder_layers = [
            LlamaDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed tokens, run every decoder layer, and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every layer.
            kv_caches: One ``KVCache`` per decoder layer, indexed by layer.
            input_metadata: Batch metadata forwarded to every layer.
            cache_events: Optional per-layer events. When ``None``, each
                layer receives ``None``.

        Returns:
            The normalized hidden states of the last layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        for idx, layer in enumerate(self.layers):
            event = None if cache_events is None else cache_events[idx]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
            )
        return self.norm(hidden_states)


class LlamaForCausalLM(nn.Module):
    """LLaMA causal language model for inference.

    Wraps ``LlamaModel`` with a tensor-parallel LM head and a ``Sampler``.
    It also maps HuggingFace LLaMA checkpoint tensors onto this model's fused
    and sharded parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LlamaModel(config)
        # Round the LM-head output dimension up to a multiple of 64, the same
        # rounding LlamaModel applies to embed_tokens. load_weights pads the
        # checkpoint rows to match.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        # The sampler is constructed with the unpadded vocabulary size.
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        """Run the decoder stack and sample the next tokens.

        The final hidden states are passed to ``self.sampler`` together with
        ``self.lm_head.weight``. Because ``lm_head`` is column-parallel, this
        is presumably this rank's shard. The sampler's result is returned
        unchanged.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    # Name patterns passed to load_tensor_parallel_weights to decide how each
    # remaining (non-fused) checkpoint weight is sharded across ranks.
    _column_parallel_weights = [
        "embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
        "gate_proj.weight", "up_proj.weight"
    ]
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        """Load HuggingFace LLaMA weights into this tensor-parallel model.

        Checkpoint tensors are handled as follows:
          * ``rotary_emb.inv_freq`` entries are skipped.
          * ``embed_tokens`` / ``lm_head`` weights are zero-padded along the
            vocabulary dimension to this model's padded vocab size.
          * ``q_proj`` / ``k_proj`` / ``v_proj`` shards for this rank are
            copied into the fused ``qkv_proj`` parameter.
          * ``gate_proj`` / ``up_proj`` shards for this rank are copied into
            the fused ``gate_up_proj`` parameter.
          * Every other weight goes through ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            use_np_cache: Forwarded to ``hf_model_weights_iterator``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()

        # Per-rank row counts of the q/k/v blocks and where each block starts
        # inside this rank's fused qkv_proj weight (layout: [q | k | v]).
        q_proj_shard_size = self.config.hidden_size // tp_size
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            if "embed_tokens" in name or "lm_head" in name:
                loaded_weight = self._pad_vocab_weight(
                    state_dict[name], loaded_weight, tp_size)

            if self._load_fused_weight(name, loaded_weight, state_dict,
                                       attention_weight_specs,
                                       tensor_model_parallel_rank):
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights,
                                         tensor_model_parallel_rank)

    def _pad_vocab_weight(self, param: torch.Tensor,
                          loaded_weight: torch.Tensor,
                          tp_size: int) -> torch.Tensor:
        """Append zero rows so a full-vocab checkpoint weight fits the padding.

        ``param`` is this rank's shard, so the padded global vocabulary size
        is ``param.shape[0] * tp_size``.
        """
        padded_vocab_size = param.shape[0] * tp_size
        num_extra_rows = padded_vocab_size - self.config.vocab_size
        # Zero-fill instead of torch.empty so padding rows never hold leftover
        # memory contents, and allocate directly in loaded_weight's
        # dtype/device instead of building a float32 temporary and casting it.
        extra_rows = loaded_weight.new_zeros(
            (num_extra_rows, loaded_weight.shape[1]))
        return torch.cat([loaded_weight, extra_rows], dim=0)

    @staticmethod
    def _copy_rank_shard(param: torch.Tensor, loaded_weight: torch.Tensor,
                         shard_size: int, rank: int, offset: int) -> None:
        """Copy rank ``rank``'s rows of ``loaded_weight`` into ``param``.

        Rows ``[shard_size * rank, shard_size * (rank + 1))`` of the
        checkpoint tensor are written to rows ``[offset, offset + shard_size)``
        of ``param``.
        """
        start = shard_size * rank
        shard = loaded_weight[start:start + shard_size]
        param_slice = param.data[offset:offset + shard_size]
        assert param_slice.shape == shard.shape
        param_slice.copy_(shard)

    def _load_fused_weight(self, name: str, loaded_weight: torch.Tensor,
                           state_dict: Dict[str, torch.Tensor],
                           attention_weight_specs: List[Tuple[str, int, int]],
                           rank: int) -> bool:
        """Load ``name`` into a fused qkv/gate_up parameter if it belongs to one.

        Returns:
            ``True`` if the weight was consumed here, ``False`` if the caller
            should load it through the generic tensor-parallel path.
        """
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name in name:
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                self._copy_rank_shard(param, loaded_weight, shard_size, rank,
                                      offset)
                return True

        # gate_up_proj holds this rank's gate rows followed by its up rows.
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name in name:
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                self._copy_rank_shard(param, loaded_weight, shard_size, rank,
                                      shard_size * stride_id)
                return True

        return False