# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
29
import os
gaoqiong's avatar
gaoqiong committed
30
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention, AttentionMetadata
33
from vllm.config import CacheConfig, LoRAConfig
34
35
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Woosuk Kwon's avatar
Woosuk Kwon committed
45
from vllm.model_executor.layers.sampler import Sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
49
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, kv_cache_scales_loader)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import SamplerOutput
52
from vllm.utils import is_hip, print_warning_once
Woosuk Kwon's avatar
Woosuk Kwon committed
53

gaoqiong's avatar
gaoqiong committed
54
from vllm import _custom_ops as ops
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections share one column-parallel linear layer whose
    output holds two ``intermediate_size`` halves. ``SiluAndMul`` combines
    those halves, and a row-parallel linear layer maps the result back to
    ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        # One fused output per projection: [gate, up].
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_widths,
            bias=bias,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused, _ = self.gate_up_proj(x)
        gated = self.act_fn(fused)
        projected, _ = self.down_proj(gated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    Q, K and V come from one fused tensor-parallel projection
    (``QKVParallelLinear``). Q and K are rotated by the rotary embedding from
    ``get_rope``, attention runs in the ``Attention`` backend, and the
    row-parallel ``o_proj`` maps the result back to ``hidden_size``.
    ``num_kv_heads`` may be smaller than ``num_heads``. KV heads are either
    partitioned across tensor-parallel ranks or replicated when there are
    fewer KV heads than ranks.
    """

    # FA_PAD support: when FA_PAD=1 and the fused QKV output is exactly this
    # wide, the trailing ``_FA_PAD_EXTRA_COLS`` columns are dropped before the
    # Q/K/V split.
    # NOTE(review): 12320 == 12288 + 32, so the check only fires for one
    # per-rank QKV width. Confirm this matches every padded checkpoint in use.
    _FA_PAD_QKV_WIDTH = 12320
    _FA_PAD_EXTRA_COLS = 32

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Read the FA_PAD switch once at construction rather than querying
        # the process environment on every forward call of every layer.
        self._fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: token positions fed to the rotary embedding.
            hidden_states: (normalized) input to the attention block.
            kv_cache: this layer's KV cache tensor.
            attn_metadata: batch metadata consumed by the attention backend.

        Returns:
            The ``o_proj`` output, ``hidden_size`` wide.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self._fa_pad and qkv.shape[-1] == self._FA_PAD_QKV_WIDTH:
            # Drop the padding columns so the Q/K/V split sizes add up.
            qkv = qkv[..., :-self._FA_PAD_EXTRA_COLS]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is threaded through ``forward`` explicitly. The norms
    receive it and hand back an updated residual, so that the next layer (or
    the final norm) can continue it.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            # NOTE: this writes into the config's own rope_scaling dict.
            if getattr(config, "original_max_position_embeddings", None):
                rope_scaling["original_max_position_embeddings"] = (
                    config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)

        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: token positions for the rotary embedding.
            hidden_states: output of the previous layer (or the embeddings).
            kv_cache: this layer's KV cache tensor.
            attn_metadata: batch metadata for the attention backend.
            residual: running residual, or ``None`` for the first layer.

        Returns:
            ``(mlp_output, residual)``.
        """
        # Self Attention
        if residual is not None:
            # Two-argument RMSNorm form: returns (normed, new_residual).
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class LlamaModel(nn.Module):
    """LLaMA transformer trunk: token embeddings, decoder stack, final RMSNorm.

    When a LoRA config is supplied, the embedding table is enlarged by
    ``lora_extra_vocab_size * (max_loras or 1)`` rows beyond
    ``config.vocab_size``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id

        if lora_config:
            extra_rows = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1)
        else:
            extra_rows = 0
        self.vocab_size = config.vocab_size + extra_rows
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config=config,
                              cache_config=cache_config,
                              quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final-normed hidden states.

        ``inputs_embeds``, when given, is used instead of looking up
        ``input_ids``. ``kv_caches`` is indexed by layer position.
        """
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        final_states, _ = self.norm(hidden_states, residual)
        return final_states


class LlamaForCausalLM(nn.Module):
    """LLaMA causal language model.

    Wraps ``LlamaModel`` with a ``ParallelLMHead``, a ``LogitsProcessor`` and a
    ``Sampler``. Also loads HuggingFace checkpoints, including the optional
    LLAMA_NN weight re-layout and FP8 KV-cache scale loading.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Parameter-name suffixes of the dense projection weights rewritten by
    # ``_apply_llama_nn_layout`` when LLAMA_NN=1. They are matched with
    # ``str.endswith``, so sibling parameters whose names merely start with
    # one of these (e.g. a ``...weight_scale``) are left alone.
    _LLAMA_NN_WEIGHT_SUFFIXES = (
        "self_attn.qkv_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.down_proj.weight",
    )

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = LlamaModel(config,
                                cache_config,
                                quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()
        # When LLAMA_NN=1, load_weights re-lays-out the dense projection
        # weights with ops.trans_w16_gemm after all weights are loaded.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states (logits come later)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states through the LM head weight into logits."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs from a HuggingFace checkpoint.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. Rotary-embedding
        caches and extra GPTQ biases are skipped. FP8 ``kv_scale`` entries are
        remapped to ``<layer>.attn.kv_scale``. When LLAMA_NN=1, the dense
        projection weights are re-laid-out afterwards.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but not found the expected name in the model "
                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
                            "not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        if self.use_llama_nn:
            # All checkpoint weights are loaded at this point.
            self._apply_llama_nn_layout(params_dict)

    def _apply_llama_nn_layout(self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Rewrite the dense projection weights in place for LLAMA_NN.

        For each parameter whose name ends with one of
        ``_LLAMA_NN_WEIGHT_SUFFIXES``, ``ops.trans_w16_gemm`` writes a
        re-laid-out copy of the 2-D weight into a scratch tensor. That copy is
        written back into the parameter, which is then reshaped to
        ``(original_cols, -1)``.
        NOTE(review): presumably a transpose into the layout expected by the
        custom 16-bit GEMM; confirm against the ``trans_w16_gemm`` kernel.
        """
        for layername, weight in params_dict.items():
            if not layername.endswith(self._LLAMA_NN_WEIGHT_SUFFIXES):
                continue
            # Scratch buffer with the same shape/dtype/device as the weight.
            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV-cache scaling factors into each layer's Attention.

        Raises:
            RuntimeError: if a layer's ``Attention`` module has no
                ``_kv_scale`` attribute to receive the factor.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer_self_attn = self.model.layers[layer_idx].self_attn

            if is_hip():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # The scale lives on the inner Attention module, which is the
            # object written below. LlamaAttention itself has no kv_scale
            # attribute, so probing it would make this method always raise.
            if hasattr(layer_self_attn.attn, "_kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")