llama.py 30.9 KB
Newer Older
1
2
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
3
# Copyright 2023 The vLLM team.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only LLaMA model compatible with HuggingFace weights."""
23
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
27

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
28
import os
gaoqiong's avatar
gaoqiong committed
29
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
30

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.model_executor.layers.activation import SiluAndMul
36
from vllm.model_executor.layers.layernorm import RMSNorm
37
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
38
39
                                               QKVParallelLinear,
                                               RowParallelLinear)
40
from vllm.model_executor.layers.logits_processor import LogitsProcessor
41
from vllm.model_executor.layers.quantization import QuantizationConfig
42
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
43
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import (
47
    default_weight_loader, maybe_remap_kv_scale_name)
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
50
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
51

52
from .interfaces import SupportsLoRA, SupportsPP
53
54
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
55
56
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
57

gaoqiong's avatar
gaoqiong committed
58
from vllm import _custom_ops as ops
59
60
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
61
62

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer whose two halves are combined by ``SiluAndMul``. A row-parallel
    down projection then maps the result back to ``hidden_size``.

    Args:
        hidden_size: Model width (input and output size of the block).
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config passed to the linears.
        bias: Whether the projections carry a bias term.
        prefix: Parameter-name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up halves share one fused matmul: [gate | up].
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # The activation is validated after the projections are built,
        # matching the original construction order.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Return ``down_proj(silu(gate(x)) * up(x))``."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """LLaMA self-attention with rotary position embeddings.

    Query, key and value come from one fused tensor-parallel projection
    (``qkv_proj``). Query heads are split evenly across TP ranks. KV heads
    are split when there are at least ``tp_size`` of them and replicated
    otherwise. An optional per-layer sliding window is read from
    ``config.interleaved_sliding_window`` (an int, or a list indexed by
    layer index modulo its length).

    Attributes set for downstream use:
        quant_config: The quantization config, or ``None`` when unquantized.
            It is always defined now; previously it existed only when
            quantized.
        quant_method: ``quant_config.get_name()``, or ``None``.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Per-rank widths of the Q and K/V slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF-quantized models whose model_type is "llama" use the
        # non-NeoX rotary style; everything else uses NeoX style.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so readers never hit AttributeError
        # on unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA transformer block: RMSNorm → attention → RMSNorm → MLP.

    The residual stream is threaded through the fused add+RMSNorm calls. The
    layer returns ``(hidden_states, residual)`` so the next layer (or the
    final norm) can perform the pending residual addition.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # --- rotary-embedding settings, with defaults for older configs ---
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config,
                                   "original_max_position_embeddings", None)
        if original_max_pos and rope_scaling is not None:
            # Note: this writes into the config's own rope_scaling dict.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # --- projection biases ---
        # attention_bias: abacusai/Smaug-72B-v0.1; bias: internlm/internlm-7b.
        # Either flag controls the o_proj bias.
        o_proj_bias = (getattr(config, "attention_bias", False)
                       or getattr(config, "bias", False))
        # internlm/internlm3-8b overrides the q/k/v bias via qkv_bias.
        qkv_bias = getattr(config, "qkv_bias", o_proj_bias)

        kv_head_count = getattr(config, "num_key_value_heads",
                                config.num_attention_heads)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_head_count,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block; ``residual`` is ``None`` only for the first layer."""
        # Self attention. Without an incoming residual, the input itself
        # becomes the residual. Otherwise the norm adds it and returns the
        # updated residual.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)

        # Fully connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_in), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
300
301


302
@support_torch_compile
class LlamaModel(nn.Module):
    """Decoder-only LLaMA backbone: token embeddings, decoder layers, final norm.

    Under pipeline parallelism each rank builds only the layers in
    ``[start_layer, end_layer)``. The embedding exists on the first rank (and
    on the last rank when embeddings are tied); the final norm exists on the
    last rank. After checkpoint loading, ``load_weights`` also rewrites some
    weights in place for backend-specific GEMM kernels. Which rewrites run
    depends on the quantization method and on the ``LLAMA_NN``, ``AWQ_PAD``
    and ``W8A8_SUPPORT_METHODS`` environment variables.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Name of the quantization method (e.g. "awq"), or None if unquantized.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Backend feature switches, read once from the environment.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # GEMM_PAD / FA_PAD are not read by this class (their uses were
        # commented out); kept as attributes for external readers.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        # W8A8 GEMM backend: 1 selects the Triton path (weights untouched);
        # other values transpose weights for rocBLAS / CUTLASS.
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's layers.

        Non-last pipeline ranks return ``IntermediateTensors``; the last rank
        returns the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, then apply backend weight post-processing.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of parameter names that received a loaded tensor.
        """
        params_dict = dict(self.named_parameters())
        loaded_params = self._load_checkpoint_weights(weights, params_dict)

        if self.use_llama_nn and self.quant_method is None:
            self._convert_llama_nn_weights(params_dict)
        if self.quant_method == "awq":
            self._convert_awq_weights(params_dict)
        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)
        return loaded_params

    def _load_checkpoint_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]],
            params_dict: Dict[str, nn.Parameter]) -> Set[str]:
        """Copy checkpoint tensors into parameters; return the loaded names."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def _convert_llama_nn_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Re-layout unquantized GEMM weights with ``ops.trans_w16_gemm`` (LLAMA_NN=1).

        Each matched weight is rewritten in place by the custom op. Its
        ``.data`` is then reshaped to ``(original_dim1, -1)``. The ``LM_NN``
        environment variable is set to ``'1'`` iff an ``lm_head.weight`` with
        ``shape[1] >= 4096`` is present (in which case it is converted too),
        and to ``'0'`` otherwise. Previously every non-lm_head parameter reset
        it to ``'0'``, so the last parameter visited decided the value.
        """
        layer_key_words = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        )
        include_lm_head = False
        os.environ['LM_NN'] = '0'
        for layername, weight in params_dict.items():
            if "lm_head.weight" in layername and weight.shape[1] >= 4096:
                include_lm_head = True
                os.environ['LM_NN'] = '1'

            matched = (any(word in layername for word in layer_key_words)
                       or (include_lm_head and "lm_head.weight" in layername))
            if not matched:
                continue

            ori_shape = weight.data.shape
            converted = torch.zeros_like(weight.data)
            ops.trans_w16_gemm(converted, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(converted)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _convert_awq_weights(self, params_dict: Dict[str,
                                                     nn.Parameter]) -> None:
        """Convert AWQ int4 weights into the layout used by the custom kernels.

        For each ``qweight``, ``ops.convert_s4`` and ``ops.sz_permute`` build
        the new ``qweight`` and fused ``zeros_and_scales`` tensors, which are
        then viewed as ``(N, -1)``. With ``AWQ_PAD=1`` and K a multiple of
        4096, zero columns are appended to both tensors.
        """
        layer_key_words = (
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        )
        group_size = self.quant_config.group_size
        # Number of zero columns appended to zeros_and_scales when padding.
        pad_group = 2
        for layername, qweight in params_dict.items():
            if not any(word in layername for word in layer_key_words):
                continue
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # View both as (N, -1). Per the original annotations, this goes
            # [K/group_size, N] -> [N, K/group_size] for zeros_and_scales and
            # [K, N/8] -> [N, K/8] for qweight. The element reordering itself
            # presumably happens in convert_s4 / sz_permute; TODO confirm.
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate padding on the parameters' own device rather than
                # the current CUDA device, so torch.cat cannot mix devices.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     int(group_size // 4),
                                     dtype=torch.int32,
                                     device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    def _get_w8a8_triton_cache(self):
        """Return the W8A8 Triton config cache, creating it on first use.

        NOTE(review): nothing in this class assigned ``self.tritonsingleton``
        before ``load_weights`` used it, which raised AttributeError. An
        instance set elsewhere is reused. Otherwise one is built from the
        file-level ``W8a8GetCacheJSON`` import; this assumes its constructor
        takes no arguments — confirm.
        """
        cache = getattr(self, "tritonsingleton", None)
        if cache is None:
            cache = W8a8GetCacheJSON()
            self.tritonsingleton = cache
        return cache

    def _prepare_w8a8_weights(self, params_dict: Dict[str,
                                                      nn.Parameter]) -> None:
        """Prepare int8 (compressed_tensors) GEMM weights for the chosen backend.

        rocBLAS and CUTLASS (``W8A8_SUPPORT_METHODS != 1``) need each weight
        transposed in place. Triton (``== 1``) must leave weights untouched.
        Instead, tuned configs are looked up for the first four matched
        weights' ``(n, k)`` and each found config is warmed up once.
        """
        layer_key_words = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        )
        use_triton = self.w8a8_strategy == 1
        triton_cache = self._get_w8a8_triton_cache() if use_triton else None

        weight_shapes: List[Tuple[int, int]] = []
        all_json: Dict[str, Any] = {}
        for layername, weight_data in params_dict.items():
            # Pattern also matches e.g. "...weight_scale"; skip scale params.
            if "scale" in layername or not any(
                    word in layername for word in layer_key_words):
                continue
            n = weight_data.shape[0]
            if not use_triton:
                transposed = weight_data.T.contiguous().reshape(n, -1)
                weight_data.data.copy_(transposed)
            elif len(weight_shapes) < 4:
                # Record the (n, k) shapes seen in the model (first four).
                k = weight_data.shape[1]
                weight_shapes.append((n, k))
                json_file = triton_cache.get_w8a8json_name(n, k)
                configs_dict = triton_cache.get_triton_cache(json_file, n, k)
                if configs_dict:
                    all_json.update(configs_dict)

        if use_triton:
            triton_cache.triton_json_dict.append(all_json)
            # Warm up once with every config found; keys look like "m_n_k...".
            for key, value in all_json.items():
                parts = key.split('_')
                m, n, k = int(parts[0]), int(parts[1]), int(parts[2])
                ops.triton_int8_gemm_helper(m=m,
                                            n=n,
                                            k=k,
                                            per_token_act_quant=True,
                                            per_out_channel_weight_quant=True,
                                            use_bias=False,
                                            best_config=value)
592

Woosuk Kwon's avatar
Woosuk Kwon committed
593

594
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
595
    # Fused (packed) parameters of this model, each mapped to the separate
    # checkpoint projections that are merged into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    # Module names that LoRA adapters are allowed to target.
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    # Embedding-type modules, tagged with whether they act as the input or
    # the output embedding.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    # Embedding modules whose vocabulary dimension is padded (cf. the
    # LoRA-dependent padding_size given to lm_head in __init__).
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints. This table maps each
    # dot-separated name token of that layout onto this model's module
    # names; it is consumed by maybe_remap_mistral.
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }
630

631
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Assemble the causal-LM wrapper around the LLaMA backbone.

        Args:
            vllm_config: Engine configuration. The HF model config, the
                quantization config and the LoRA config are read from it.
            prefix: Name prefix for this module's parameters; sub-module
                prefixes are derived from it with ``maybe_prefix``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        lora_cfg = vllm_config.lora_config

        self.config = hf_config
        self.lora_config = lora_cfg

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        # Handle to the W8A8 Triton GEMM config cache.
        # NOTE(review): no method of this class visible here reads it —
        # confirm whether external code relies on this attribute.
        self.tritonsingleton = W8a8GetCacheJSON()

        if not get_pp_group().is_last_rank:
            # Only the last pipeline stage builds a real LM head; the other
            # stages keep a placeholder.
            self.lm_head = PPMissingLayer()
        else:
            vocab = hf_config.vocab_size
            if lora_cfg:
                # LoRA may add extra vocabulary entries, and its kernels
                # need a bigger vocab padding for compatibility.
                vocab += lora_cfg.lora_extra_vocab_size
                padding = lora_cfg.lora_vocab_padding_size
            else:
                padding = DEFAULT_VOCAB_PADDING_SIZE
            self.unpadded_vocab_size = vocab

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=padding,
                quant_config=quant_cfg,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                # Tie the head to the backbone's input embedding table.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
            )

        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
zhuwenwen's avatar
zhuwenwen committed
676
        
Woosuk Kwon's avatar
Woosuk Kwon committed
677

678
679
680
    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build and return the decoder backbone (a ``LlamaModel``).

        Kept as a separate factory method, presumably so subclasses can
        substitute a different backbone.
        """
        backbone = LlamaModel(vllm_config=vllm_config, prefix=prefix)
        return backbone

681
682
683
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids`` (delegates to the
        backbone)."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

Woosuk Kwon's avatar
Woosuk Kwon committed
684
685
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        The result is either a hidden-state tensor or an
        ``IntermediateTensors`` bundle, as the return annotation states.
        The latter is presumably produced on non-final pipeline stages.
        Logits are computed separately by ``compute_logits``.
        """
        return self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )
697

698
699
700
701
702
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` through the LM head into logits.

        The projection is delegated to ``self.logits_processor``, which is
        given the LM head module together with the sampling metadata.
        """
        processor = self.logits_processor
        return processor(self.lm_head, hidden_states, sampling_metadata)

707
708
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

zhuwenwen's avatar
zhuwenwen committed
712

713
714
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``, so Mistral-layout names are translated to
        this model's names. The pairs are then handed to
        ``AutoWeightsLoader``. When word embeddings are tied, weights under
        ``lm_head.`` are skipped, because ``__init__`` ties the head to the
        input embeddings.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``
            (presumably the names of the parameters that were loaded).
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        # Lazily remap names as the loader consumes the iterator.
        remapped = (self.maybe_remap_mistral(key, tensor)
                    for key, tensor in weights)
        return loader.load_weights(remapped)
723

zhuwenwen's avatar
zhuwenwen committed
724

725
726
727
    # Translate one checkpoint entry from the Mistral "consolidated" layout
    # (as used by Mistral and Llama <= 2) to this model's naming.
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Remap a Mistral-format weight name (and q/k tensor layout).

        Args:
            name: Dot-separated checkpoint parameter name.
            loaded_weight: The tensor stored under ``name``.

        Returns:
            ``(new_name, tensor)``. Each name token found in
            ``self.mistral_mapping`` is substituted with its target, unless
            that target already occurs in the name. Tensors whose name
            contains a ``wq`` or ``wk`` token are row-permuted; all other
            tensors are returned unchanged.
        """
        mapping = self.mistral_mapping

        def permute(w: torch.Tensor, n_heads: int):
            # Within each head, regroup rows from interleaved pairs
            # (a0, b0, a1, b1, ...) into two contiguous halves
            # (a0, a1, ..., b0, b1, ...).
            rows = self.config.head_dim * n_heads
            cols = self.config.hidden_size
            paired = w.view(n_heads, rows // n_heads // 2, 2, cols)
            return paired.transpose(1, 2).reshape(rows, cols)

        tokens = name.split(".")

        # Only the rotary-embedded projections (q and k) need reordering.
        if "wk" in tokens:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in tokens:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for token in tokens:
            if token not in mapping:
                continue
            target = mapping[token]
            # Skip names that already use the target naming.
            if target not in name:
                name = name.replace(token, target)

        return name, loaded_weight