# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig

from copy import deepcopy

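# A per-layer KV cache is a (key_cache, value_cache) tensor pair.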
KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
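    """Compute the ALiBi slope for each of `total_num_heads` heads.

    For the closest power-of-two head count n, the slopes are the geometric
    sequence 2^(-8/n), 2^(-16/n), ...; when the head count is not a power of
    two, the remaining heads receive interpolated slopes derived from 2n,
    following the ALiBi paper.
    """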
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class LlamaMLP(nn.Module):
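    """Feed-forward block: fused gate/up projection, SiLU-and-multiply,
    then a down projection."""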

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        hidden_act = getattr(config, "hidden_act", "silu")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
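    """Self-attention block backed by vLLM's PagedAttention.

    Supports grouped-query attention (fewer KV heads than query heads),
    rotary or ALiBi position embeddings, and an optional sliding window,
    all selected from the model config.
    """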

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = getattr(config, "num_attention_heads", None)
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # Default to MHA when the config does not define num_key_value_heads.
        self.total_num_kv_heads = getattr(config, "num_key_value_heads",
                                          self.total_num_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.max_position_embeddings = max_position_embeddings

        # InternLM-style configs add a bias to the attention projections.
        bias = getattr(config, "bias", False)

        # StableLM-style configs add a bias to the QKV projection only.
        qkv_bias = getattr(config, "use_qkv_bias", False)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias or qkv_bias,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=bias,
            linear_method=linear_method,
        )

        # Mistral-style configs enable sliding-window attention.
        sliding_window = getattr(config, "sliding_window", None)

        self.postion_embedding = getattr(config, "postion_embedding", "ROPE")
        # Create the alibi slopes and slice them.
        if self.postion_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.scaling,
                                       alibi_slopes=alibi_slopes,
                                       sliding_window=sliding_window)
        else:
            rope_theta = getattr(config, "rope_theta", 10000)
            rope_scaling = getattr(config, "rope_scaling", None)
            # stablelm
            rope_pct = getattr(config, "rope_pct", 1)
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=int(self.head_dim * rope_pct),
                max_position=max_position_embeddings,
                base=rope_theta,
                rope_scaling=rope_scaling,
            )
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.scaling,
                                       num_kv_heads=self.num_kv_heads,
                                       sliding_window=sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
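    """One transformer block: RMSNorm + self-attention and RMSNorm + MLP,
    each wrapped in a residual connection."""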

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        norm: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(
            config,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            config,
            linear_method=linear_method,
        )
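        # `norm` is a shared RMSNorm template passed in by the caller;
        # deep-copy it so each layer owns independent normalization weights.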
        self.input_layernorm = deepcopy(norm)
        self.post_attention_layernorm = deepcopy(norm)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
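        # The RMSNorm here takes the running residual and returns both the
        # normalized hidden states and the updated residual stream.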
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
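    """Token embedding, the stack of decoder layers, and the final norm."""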

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        norm: Optional[torch.Tensor] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config, linear_method, norm)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = norm

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
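    """LlamaModel with a parallel LM head and a sampler on top."""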
    supports_lora = True

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
        norm: Optional[torch.Tensor] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
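        # Build a default RMSNorm if the caller did not provide one; the model
        # deep-copies it into every decoder layer.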
        if norm is None:
            norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.model = LlamaModel(config,
                                linear_method,
                                norm=norm,
                                lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
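        # HF checkpoints store q/k/v and gate/up projections as separate
        # tensors; map them onto the fused qkv_proj and gate_up_proj
        # parameters used by this implementation.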
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
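                # for-else: no fused-parameter mapping matched, so load the
                # tensor directly into the parameter of the same name.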
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)