# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""

from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
from vllm.sequence import SamplerOutput

# Per-layer KV cache: a (key_cache, value_cache) pair of tensors. The model
# receives one entry per decoder layer, and attention unpacks it as
# ``k_cache, v_cache``.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer (``gate_up_proj``). ``SiluAndMul`` combines its output into an
    ``intermediate_size``-wide activation, and the row-parallel
    ``down_proj`` projects that back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the MLP.

        Args:
            hidden_size: Model hidden dimension (input and output width).
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name from the model config. Only
                ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                projection layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported configs before allocating the (possibly large,
        # possibly quantized) projection weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Fused gate+up projection: 2 * intermediate_size output columns.
        # The output stays sharded (gather_output=False) because down_proj
        # is built with input_is_parallel=True and consumes the shard
        # directly.
        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
                                                  quant_config=quant_config)
        self.down_proj = ParallelLinear.row(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            input_is_parallel=True,
                                            quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate_up_proj -> SiluAndMul -> down_proj to ``x``.

        The projection layers return a pair; only the first element (the
        output tensor) is used.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Self-attention with rotary position embeddings and a paged KV cache.

    Supports grouped-query attention (``num_kv_heads`` may be smaller than
    ``num_heads``). Q, K and V come from one fused column-parallel
    projection and are split locally. ``PagedAttentionWithRoPE`` receives
    positions, q/k/v and the cache tensors. The result is projected back to
    ``hidden_size`` by a row-parallel ``o_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention block for the current tensor-parallel group.

        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP
                ranks.
            rope_theta: RoPE base, forwarded to the attention kernel as
                ``base``.
            rope_scaling: Optional RoPE scaling config, forwarded unchanged.
            max_position_embeddings: Forwarded as ``max_position``.
            quant_config: Optional quantization config for qkv_proj and
                o_proj.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``,
                if ``num_heads`` is not divisible by the TP world size, or if
                the KV heads can be neither evenly partitioned nor evenly
                replicated across TP ranks.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Explicit raises instead of `assert`, which `python -O` strips and
        # which would then let a bad config silently mis-shard the heads.
        if hidden_size % self.total_num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({self.total_num_heads}).")
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) must "
                f"be divisible by the tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads <= 0:
            raise ValueError(
                f"num_kv_heads must be positive, got {self.total_num_kv_heads}.")
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            if self.total_num_kv_heads % tp_size != 0:
                raise ValueError(
                    f"Number of KV heads ({self.total_num_kv_heads}) must be "
                    f"divisible by the tensor parallel size ({tp_size}).")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            if tp_size % self.total_num_kv_heads != 0:
                raise ValueError(
                    f"Tensor parallel size ({tp_size}) must be divisible by "
                    f"the number of KV heads ({self.total_num_kv_heads}).")
        # KV heads held by this rank (at least 1 when replicating).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # How many ranks share each KV head (1 when partitioning).
        num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v segments of the fused output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # The global output width counts replicated KV heads once per
        # replica, so after column sharding every rank gets num_heads query
        # heads plus num_kv_heads K and V heads.
        self.qkv_proj = ParallelLinear.column(
            hidden_size,
            (self.total_num_heads +
             2 * self.total_num_kv_heads * num_kv_heads_replicas) *
            self.head_dim,
            bias=False,
            gather_output=False,
            quant_config=quant_config,
        )
        self.o_proj = ParallelLinear.row(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention for this rank's heads and return the o_proj output.

        ``positions``, ``input_metadata`` and ``cache_event`` are passed
        through to ``PagedAttentionWithRoPE`` unchanged. ``kv_cache`` is
        unpacked into its key and value cache tensors for the same call.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # This rank's fused output is laid out [q | k | v] on the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA transformer block.

    The forward pass computes ``h = x + attn(input_layernorm(x))`` and then
    ``h + mlp(post_attention_layernorm(h))``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention and MLP sub-blocks from a HF ``LlamaConfig``.

        Args:
            config: HF model config. ``rope_theta``, ``rope_scaling`` and
                ``max_position_embeddings`` are optional and fall back to
                10000, None and 8192 respectively.
            quant_config: Optional quantization config for the linear layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0. The getattr defaults keep configs
        # that lack these attributes loadable.
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and MLP residual sub-blocks in sequence.

        ``positions``, ``kv_cache``, ``input_metadata`` and ``cache_event``
        are forwarded unchanged to the attention sub-block.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm.

        Args:
            config: HF model config.
            quant_config: Optional quantization config forwarded to every
                decoder layer. The embedding is built without it.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Unpadded vocabulary size, as reported by the config.
        self.vocab_size = config.vocab_size

        # The embedding table is rounded up to a multiple of 64 rows,
        # presumably so it splits evenly across tensor-parallel ranks.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, and normalize.

        Args:
            input_ids: Token ids to embed.
            positions: Forwarded unchanged to every layer.
            kv_caches: One cache pair per decoder layer, indexed by layer.
            input_metadata: Forwarded unchanged to every layer.
            cache_events: Either None, or one event per layer, indexed by
                layer. With None, every layer gets ``cache_event=None``.

        Returns:
            The final RMS-normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA decoder stack plus an LM head and sampler, with an HF loader.

    ``forward`` runs :class:`LlamaModel` and passes the final hidden states
    and the LM-head weight to :class:`Sampler`. ``load_weights`` maps
    HuggingFace checkpoint tensors onto this module's fused,
    tensor-parallel-sharded parameters.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the model, LM head and sampler.

        Args:
            config: HF model config.
            quant_config: Optional quantization config for the decoder
                stack. The LM head is always unquantized.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(config, quant_config)
        # Same multiple-of-64 rounding as LlamaModel applies to its
        # embedding table.
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        # NOTE: The LM head is not quantized.
        self.lm_head = ParallelLinear.column(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             quant_config=None)
        # The sampler receives the unpadded vocab size, presumably so the
        # padded logit columns are ignored.
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder stack and sample next tokens.

        All arguments are forwarded to :class:`LlamaModel`. ``input_metadata``
        is also given to the sampler together with the LM-head weight.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    # Layer-name prefixes (combined with weight suffixes in load_weights)
    # handed to load_tensor_parallel_weights as column-/row-parallel
    # weights. The column list is empty because every column-parallel layer
    # here (qkv_proj, gate_up_proj, embed_tokens, lm_head) is handled
    # explicitly before that fallback.
    _column_parallel_layers = []
    _row_parallel_layers = ["o_proj", "down_proj"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Load a HuggingFace LLaMA checkpoint into this rank's parameters.

        HF stores ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        as separate tensors. They are sliced for this tensor-parallel rank
        and copied into the fused ``qkv_proj`` / ``gate_up_proj``
        parameters. Embedding and LM-head tensors go through
        ``load_padded_tensor_parallel_vocab``. Everything else goes through
        ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.
        """
        if self.quant_config is None:
            col_weight_suffixes = ["weight"]
            row_weight_suffixes = ["weight"]
        else:
            col_weight_suffixes = (
                self.quant_config.get_col_parallel_tensor_names())
            row_weight_suffixes = (
                self.quant_config.get_row_parallel_tensor_names())

        column_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._column_parallel_layers
            for suffix in col_weight_suffixes
        ]
        row_parallel_weights: List[str] = [
            f"{layer}.{suffix}" for layer in self._row_parallel_layers
            for suffix in row_weight_suffixes
        ]

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        # Per-rank widths of the q and k/v segments inside fused qkv_proj.
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        num_kv_heads_replicas = max(1,
                                    tp_size // self.config.num_key_value_heads)
        num_kv_heads_per_gpu = max(1,
                                   self.config.num_key_value_heads // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              num_kv_heads_per_gpu)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Some checkpoints also store precomputed RoPE cos/sin
                # caches. This model has no `rotary_emb` submodule, so the
                # state_dict lookup below would raise KeyError. Skip them.
                continue

            packed_dim = None
            is_transposed = False
            if self.quant_config is not None:
                packed_dim = self.quant_config.get_packed_dim(name)
                is_transposed = self.quant_config.is_transposed(name)
            if is_transposed:
                # loaded_weight may be a lazy slice object; materialize it
                # before transposing.
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight = loaded_weight.T

            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "qkv_proj")]
                if is_transposed:
                    # Transpose the destination view too, so slicing below
                    # is along dim 0 for both tensors.
                    param = param.T

                if packed_dim is not None:
                    shard_dim = 0 if not is_transposed else 1
                    if packed_dim == shard_dim:
                        # Packed along the sharded dim: the slice and its
                        # offset shrink by the packing factor.
                        shard_size //= self.quant_config.pack_factor
                        offset //= self.quant_config.pack_factor

                if weight_name in ["k_proj", "v_proj"]:
                    # When KV heads are replicated, num_kv_heads_replicas
                    # consecutive ranks load the same KV shard.
                    shard_id = tp_rank // num_kv_heads_replicas
                else:
                    shard_id = tp_rank
                loaded_weight = loaded_weight[shard_size *
                                              shard_id:shard_size *
                                              (shard_id + 1)]
                param_slice = param.data[offset:offset + shard_size]
                assert param_slice.shape == loaded_weight.shape

                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
            # gate_up_proj holds [gate | up] for this rank; stride_id picks
            # which half this checkpoint tensor fills.
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                if is_transposed:
                    param = param.T

                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            if is_transposed:
                param = param.T

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         column_parallel_weights,
                                         row_parallel_weights, tp_rank)