# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""

from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput

# Per-layer KV cache, unpacked by the attention layer as
# ``(key_cache, value_cache)``.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """LLaMA feed-forward block.

    A fused gate/up projection (column-parallel), ``SiluAndMul`` applied to
    the fused output, then a row-parallel down projection back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the MLP.

        Args:
            hidden_size: Model hidden dimension (input and output width).
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to the
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before allocating the (potentially large) projection
        # weights so an unsupported config fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # gate_proj and up_proj are fused into one projection of width
        # 2 * intermediate_size; the output is kept sharded
        # (gather_output=False) and fed straight into the row-parallel
        # down_proj.
        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
                                                  perform_initialization=False,
                                                  quant_config=quant_config)
        self.down_proj = ParallelLinear.row(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            input_is_parallel=True,
                                            perform_initialization=False,
                                            quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to hidden states ``x`` (last dim ``hidden_size``).

        The projections return a pair; only the first element is used here
        (the second is presumably an optional bias -- confirm in
        ParallelLinear).
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """LLaMA self-attention with rotary position embeddings.

    Supports grouped-query attention via ``num_kv_heads``. Both the query and
    the key/value heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention layer.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads (across all ranks).
            num_kv_heads: Total number of key/value heads (across all ranks).
            rope_theta: Rotary embedding base.
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            quant_config: Optional quantization config forwarded to the
                linear layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0, (
            f"num_kv_heads ({self.total_num_kv_heads}) must be divisible by "
            f"the tensor parallel size ({tp_size}).")
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # q, k and v projections fused into one column-parallel layer; each
        # rank holds its [q | k | v] shard.
        self.qkv_proj = ParallelLinear.column(
            hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            perform_initialization=False,
            quant_config=quant_config,
        )
        self.o_proj = ParallelLinear.row(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
            quant_config=quant_config,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        ``positions`` and ``cache_event`` are forwarded unchanged to the
        paged attention op (presumably for rotary embedding and cache-copy
        synchronisation respectively -- confirm in PagedAttentionWithRoPE).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA transformer layer.

    RMSNorm -> self-attention -> residual add, then
    RMSNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the layer from a HuggingFace ``LlamaConfig``.

        Args:
            config: Model configuration.
            quant_config: Optional quantization config forwarded to the
                attention and MLP linear layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        # Older configs may lack these attributes, so fall back to defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks, each with a residual."""
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, N decoder layers, final norm."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the model from a HuggingFace ``LlamaConfig``.

        Args:
            config: Model configuration.
            quant_config: Optional quantization config forwarded to each
                decoder layer (the embedding is built without it).
        """
        super().__init__()
        self.config = config
        # Stored for reference; not used elsewhere in this class.
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Round the embedding table up to a multiple of 64 rows. Presumably
        # this gives an even tensor-parallel split; the extra rows are
        # filled by the padded-vocab weight loader.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size, config.hidden_size, perform_initialization=False)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` and, when given, ``cache_events`` are indexed by layer
        position, so each must have at least one entry per layer.

        Returns:
            The final RMS-normalised hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack plus an (unquantized) LM head and a sampler."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the causal-LM wrapper.

        Args:
            config: Model configuration.
            quant_config: Optional quantization config for the decoder
                layers; the LM head is always built unquantized.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config)
        # Same multiple-of-64 padding as the embedding table in LlamaModel.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        # NOTE: The LM head is not quantized.
        self.lm_head = ParallelLinear.column(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False,
                                             quant_config=None)
        # The sampler is given the real (unpadded) vocabulary size.
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder and sample next tokens using the LM head weight."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    # Layer-name fragments passed to load_tensor_parallel_weights to select
    # how non-fused weights are sharded.
    _column_parallel_layers = []
    _row_parallel_layers = ["o_proj", "down_proj"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load HuggingFace checkpoint weights into this model's shards.

        Handles checkpoint tensors as follows:
          * ``rotary_emb.inv_freq`` entries are skipped.
          * Tensors the quant config marks as transposed are transposed
            (together with the destination parameter view) before copying.
          * ``q_proj``/``k_proj``/``v_proj`` are copied into this rank's
            ``[q | k | v]`` slice of the fused ``qkv_proj`` parameter.
          * ``gate_proj``/``up_proj`` are copied into the first/second half
            of this rank's fused ``gate_up_proj`` parameter.
          * ``embed_tokens``/``lm_head`` go through the padded-vocab loader.
          * Everything else goes through ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Checkpoint name or local path.
            cache_dir: Optional download cache directory.
            load_format: Checkpoint format selector passed to the iterator.
            revision: Optional checkpoint revision.
        """
        if self.quant_config is None:
            weight_suffixes = ["weight"]
        else:
            weight_suffixes = self.quant_config.get_tp_tensor_names()

        column_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._column_parallel_layers
            for suffix in weight_suffixes
        ]
        row_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._row_parallel_layers
            for suffix in weight_suffixes
        ]

        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        # Per-rank row counts of the q and k/v parts of qkv_proj, and their
        # offsets inside this rank's fused parameter.
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            is_packed = False
            is_transposed = False
            if self.quant_config is not None:
                is_packed = self.quant_config.is_packed(name)
                is_transposed = self.quant_config.is_transposed(name)
            if is_transposed:
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight = loaded_weight.T

            # Fused q/k/v projection.
            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                if is_transposed:
                    param = param.T

                # Packed tensors hold pack_factor values per element, so
                # the row counts shrink accordingly.
                if is_packed:
                    shard_size //= self.quant_config.pack_factor
                    offset //= self.quant_config.pack_factor

                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[offset:offset + shard_size]
                assert param_slice.shape == loaded_weight.shape

                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            # Fused gate/up projection: gate -> first half, up -> second.
            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                if is_transposed:
                    param = param.T

                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[
                    shard_size * tensor_model_parallel_rank:shard_size *
                    (tensor_model_parallel_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            if is_transposed:
                param = param.T

            # Vocabulary-sized tensors are padded (see __init__), so they
            # need the padded-vocab loader.
            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         column_parallel_weights,
                                         row_parallel_weights,
                                         tensor_model_parallel_rank)