llama.py 32.1 KB
Newer Older
1
2
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
3
# Copyright 2023 The vLLM team.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only LLaMA model compatible with HuggingFace weights."""
23
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
27

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
28
import os
gaoqiong's avatar
gaoqiong committed
29
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
30

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
44
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
46
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.platforms import current_platform
53
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
54
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
from .interfaces import SupportsLoRA, SupportsPP
57
58
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
59
60
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
61

gaoqiong's avatar
gaoqiong committed
62
from vllm import _custom_ops as ops
63
64
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
65
66

class LlamaMLP(nn.Module):
    """Feed-forward block of a LLaMA decoder layer.

    The input goes through a merged gate/up projection, the ``SiluAndMul``
    activation, and a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the projection layers.

        Args:
            hidden_size: Input size of the gate/up projection and output
                size of the down projection.
            intermediate_size: Output size of each of the two merged
                gate/up projections and input size of the down projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            bias: Whether the linear layers carry a bias term.
            prefix: Dotted module-name prefix forwarded to the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out


class LlamaAttention(nn.Module):
    """Self-attention block of a LLaMA decoder layer.

    Holds the fused QKV projection, the rotary embedding, the attention
    backend wrapper and the output projection.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention sub-layers.

        Args:
            config: HF model config; read for the optional ``head_dim`` and
                ``interleaved_sliding_window`` attributes.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before tensor-parallel
                partitioning).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Optional quantization config.
            bias: Whether the QKV and output projections carry a bias.
            cache_config: Optional cache config forwarded to ``Attention``.
            prefix: Dotted module-name prefix; the layer index is extracted
                from it.

        Raises:
            ValueError: If ``config.interleaved_sliding_window`` is neither
                an ``int`` nor a ``list``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Always define both attributes so readers never hit a missing
        # ``quant_config`` on unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF checkpoints use the non-neox rotary layout.
        is_neox_style = self.quant_method != "gguf"

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                # The per-layer window cycles through the configured list.
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
            # Drop the trailing 32 columns of the projection output.
            # NOTE(review): presumably this undoes the 32-wide padding that
            # LlamaModel.load_weights applies to the qkv weight when FA
            # padding is enabled -- confirm against pad_weight.
            qkv = qkv[..., :-32]
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One LLaMA decoder layer: self-attention and MLP, each preceded by RMSNorm."""

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention, MLP and normalization sub-layers.

        Args:
            config: HF model config supplying sizes, RoPE settings, bias
                flags and the RMSNorm epsilon.
            cache_config: Optional cache config forwarded to the attention.
            quant_config: Optional quantization config forwarded to the
                attention and the MLP.
            prefix: Dotted module-name prefix for this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            # Note: this writes into the dict object held by ``config``.
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            # Fall back to the query-head count when the config does not
            # declare a separate number of KV heads.
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` for the first layer processed; in that case
        the incoming hidden states become the residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns both the normalized
            # states and the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
298
299


300
@support_torch_compile
class LlamaModel(nn.Module):
    """Stack of LLaMA decoder layers with embedding and final norm.

    Under pipeline parallelism only the layers in
    ``[start_layer, end_layer)`` are run by ``forward``; the embedding and
    the final norm are replaced by ``PPMissingLayer`` on ranks that do not
    need them.

    NOTE(review): ``load_weights`` reads ``self.use_llama_nn``,
    ``self.use_gemm_pad``, ``self.use_fa_pad``, ``self.use_awq_pad``,
    ``self.quant_method``, ``self.quant_config``, ``self.w8a8_strategy`` and
    ``self.tritonsingleton``. None of them is assigned in ``__init__`` here,
    so they must be set elsewhere before weights are loaded -- confirm.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        """Build embedding, decoder layers and final norm from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config and the cache,
                quantization and LoRA configs.
            prefix: Dotted module-name prefix for the model.
            layer_type: Decoder-layer class to instantiate for each layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra vocabulary rows reserved for LoRA adapters, if any.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. On other ranks the
        hidden states and residual come from ``intermediate_tensors``.

        Returns:
            The normalized hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches is indexed relative to this rank's first layer.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. After loading,
        backend-specific weight re-layout is applied depending on
        ``self.use_llama_nn`` and ``self.quant_method``.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint entries.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_unquantized_weights(params_dict)

        if self.quant_method == "awq":
            self._relayout_awq_weights(params_dict)

        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)

        return loaded_params

    @staticmethod
    def _name_has_any(name: str, keys: List[str]) -> bool:
        """Return True if ``name`` contains any of ``keys`` as a substring."""
        return any(key in name for key in keys)

    def _relayout_unquantized_weights(
            self, params_dict: Dict[str, torch.Tensor]) -> None:
        """Pad and re-layout the unquantized linear weights in place.

        Also publishes ``LM_NN`` in the environment: ``'1'`` when an
        ``lm_head.weight`` with at least 4096 columns was re-laid out,
        ``'0'`` otherwise.
        """
        layer_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        qkv_key = "self_attn.qkv_proj.weight"

        lm_head_relayout = False
        for layername, weight in params_dict.items():
            # shape[1] is only read for lm_head, so 1-D parameters are safe.
            is_large_lm_head = ("lm_head.weight" in layername
                                and weight.shape[1] >= 4096)
            lm_head_relayout = lm_head_relayout or is_large_lm_head

            if not (is_large_lm_head
                    or self._name_has_any(layername, layer_keys)):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # The qkv weight is padded for FA unless the (possibly already
            # padded) first dimension satisfies gemm_bank_conf.
            if (self.use_fa_pad and qkv_key in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            relayout = torch.zeros_like(weight.data)
            src_shape = relayout.shape
            # NOTE(review): trans_w16_gemm presumably writes the transposed
            # layout of weight.data into ``relayout`` -- confirm.
            ops.trans_w16_gemm(relayout, weight.data, src_shape[0],
                               src_shape[1])
            weight.data.copy_(relayout)
            weight.data = weight.data.reshape(src_shape[1], -1)

        # Set once after the loop: writing it per parameter let every
        # parameter visited after lm_head reset the flag to '0'.
        os.environ['LM_NN'] = '1' if lm_head_relayout else '0'

    def _relayout_awq_weights(self,
                              params_dict: Dict[str, torch.Tensor]) -> None:
        """Convert AWQ qweight/zeros/scales tensors in place, with optional padding."""
        layer_keys = [
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        ]
        pad_group = 2

        for layername in params_dict:
            if not self._name_has_any(layername, layer_keys):
                continue

            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            group_size = self.quant_config.group_size
            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            converted_qw, converted_sz = ops.convert_s4(
                qweight, qzeros, scales, int(group_size))
            permuted_sz = ops.sz_permute(converted_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(permuted_sz)
            qweight.data.copy_(converted_qw)

            # Reshape:
            #   zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            #   qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Create the padding on the parameter's own device instead
                # of whichever CUDA device happens to be current.
                sz_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, sz_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self,
                              params_dict: Dict[str, torch.Tensor]) -> None:
        """Re-layout compressed-tensors weights, or warm up Triton configs.

        When ``self.w8a8_strategy`` is 1 the weights are left untouched and
        cached Triton configs are collected and warmed up instead.
        """
        layer_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        weight_shapes = []
        all_json = {}

        for layername in params_dict:
            if (not self._name_has_any(layername, layer_keys)
                    or "scale" in layername):
                continue

            weight_data = params_dict[layername]
            n = weight_data.shape[0]

            # rocblas and cutlass currently both need the weight to be
            # re-laid out; triton (strategy 1) does not.
            if self.w8a8_strategy != 1:
                relayout = weight_data.T.contiguous().reshape(n, -1)
                weight_data.data.copy_(relayout)

            # Record the (n, k) shapes seen in the model; only the first
            # four are looked up.
            elif len(weight_shapes) < 4:
                k = weight_data.shape[1]
                # A tuple keeps both values even when n == k.
                weight_shapes.append((n, k))

                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(
                    json_file, n, k)
                if configs_dict:
                    all_json.update(configs_dict)

        if self.w8a8_strategy == 1:
            self.tritonsingleton.triton_json_dict.append(all_json)
            # Run one warmup for every config that was found.
            for key, value in all_json.items():
                # Keys are "m_n_k..." strings; the first three fields are
                # the GEMM dimensions.
                fields = key.split('_')
                m = int(fields[0])
                n = int(fields[1])
                k = int(fields[2])
                ops.triton_int8_gemm_helper(m=m,
                                            n=n,
                                            k=k,
                                            per_token_act_quant=True,
                                            per_out_channel_weight_quant=True,
                                            use_bias=False,
                                            best_config=value)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV cache scaling factors from a file.

        Args:
            quantization_param_path: Path handed to
                ``kv_cache_scales_loader``.

        Raises:
            RuntimeError: If a layer's attention module has no ``kv_scale``
                attribute.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                # Placeholder layer with no attention module. Skip it so we
                # never use an unbound or stale ``layer_self_attn``.
                continue
            layer_self_attn = layer.self_attn

            if current_platform.is_rocm():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            # NOTE(review): the check looks for ``kv_scale`` on the attention
            # module while the write targets ``attn._kv_scale`` -- confirm
            # the attribute names against the Attention implementation.
            if hasattr(layer_self_attn, "kv_scale"):
                layer_self_attn.attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
595

Woosuk Kwon's avatar
Woosuk Kwon committed
596

597
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Llama model for causal language modelling.

    The class-level tables below are static naming metadata: which
    checkpoint modules are packed into fused modules, which modules are
    LoRA-capable, and how Mistral-format names translate to the names used
    by this implementation.
    """

    # Fused module name -> names of the modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # bitsandbytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name: (weight_name, index)
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints. Keys are Mistral-format
    # name components, values are the names used by this implementation.
    mistral_mapping = {
        # Containers
        "layers": "model.layers",
        "attention": "self_attn",
        "feed_forward": "mlp",
        # Attention projections
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        # MLP projections
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        # Normalisation layers
        "attention_norm": "input_layernorm",
        "ffn_norm": "post_attention_layernorm",
        "norm": "model.norm",
        # Embedding and output head
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
    }
643

644
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model, the LM head (last pipeline rank only) and flags.

        Args:
            vllm_config: Source of the HF model config, the quantization
                config and the LoRA config.
            prefix: Name prefix passed through ``maybe_prefix`` to the
                sub-modules.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        self.tritonsingleton = W8a8GetCacheJSON()

        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank owns a real LM head.
            self.lm_head = PPMissingLayer()
        else:
            if lora_config:
                # We need bigger padding if using lora for kernel
                # compatibility
                vocab_padding = lora_config.lora_vocab_padding_size
                extra_vocab = lora_config.lora_extra_vocab_size
                self.unpadded_vocab_size = hf_config.vocab_size + extra_vocab
            else:
                vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
                self.unpadded_vocab_size = hf_config.vocab_size

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
            )

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        # NOTE(review): ``self.quant_config`` is only assigned when a
        # quantization config was supplied; otherwise just ``quant_method``
        # exists (as None).
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        def _env_flag(var_name: str) -> bool:
            # A switch counts as enabled only when set to exactly '1'.
            return os.environ.get(var_name) == '1'

        self.use_llama_nn = _env_flag('LLAMA_NN')
        # self.use_lm_nn = os.environ.get('LM_NN') == '1'
        self.use_gemm_pad = _env_flag('GEMM_PAD')
        self.use_fa_pad = _env_flag('FA_PAD')
        self.use_awq_pad = _env_flag('AWQ_PAD')
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
Woosuk Kwon's avatar
Woosuk Kwon committed
701

702
703
704
    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Construct and return the ``LlamaModel`` wrapped by this class."""
        backbone = LlamaModel(vllm_config=vllm_config, prefix=prefix)
        return backbone

705
706
707
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's input embeddings for ``input_ids``."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

Woosuk Kwon's avatar
Woosuk Kwon committed
708
709
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped model and return its output unchanged.

        All arguments are forwarded positionally, in declaration order, to
        ``self.model``.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)
721

722
723
724
725
726
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply ``self.logits_processor`` to the LM head and hidden states.

        Returns whatever the logits processor returns.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

731
732
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Run ``self.sampler`` on ``logits`` and return its result."""
        return self.sampler(logits, sampling_metadata)

zhuwenwen's avatar
zhuwenwen committed
736

737
738
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. When the config ties word embeddings,
        entries under the ``lm_head.`` prefix are skipped.

        Returns:
            The value returned by the loader's ``load_weights``.
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)

        # Lazily remap names/tensors as the loader consumes them.
        remapped = (self.maybe_remap_mistral(weight_name, tensor)
                    for weight_name, tensor in weights)
        return loader.load_weights(remapped)
747

zhuwenwen's avatar
zhuwenwen committed
748

749
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Forward KV cache scale loading to the wrapped model."""
        backbone = self.model
        backbone.load_kv_cache_scales(quantization_param_path)
751
752
753
754

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight entry to this model's naming.

        Each dot-separated component of ``name`` that is a key of
        ``self.mistral_mapping`` is replaced by its mapped value, unless
        that value already occurs in the name. Weights whose name contains
        a ``wk`` or ``wq`` component are additionally row-permuted.

        Args:
            name: Dotted weight name.
            loaded_weight: Tensor stored under ``name``.

        Returns:
            The (possibly renamed) name and (possibly permuted) tensor.
        """

        def _regroup_rows(weight: torch.Tensor, head_count: int):
            # View each head's rows as (half, 2), swap those two axes and
            # flatten back to the original 2-D shape.
            rows = self.config.head_dim * head_count
            cols = self.config.hidden_size
            half_rows_per_head = rows // head_count // 2
            grouped = weight.view(head_count, half_rows_per_head, 2, cols)
            return grouped.transpose(1, 2).reshape(rows, cols)

        table = self.mistral_mapping
        components = name.split(".")

        # rotary embeds should be sliced
        if "wk" in components:
            loaded_weight = _regroup_rows(loaded_weight,
                                          self.config.num_key_value_heads)
        elif "wq" in components:
            loaded_weight = _regroup_rows(loaded_weight,
                                          self.config.num_attention_heads)

        for component in components:
            if component not in table:
                continue
            replacement = table[component]
            if replacement in name:
                # Already carries the target name; leave it alone.
                continue
            name = name.replace(component, replacement)

        return name, loaded_weight