# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
25
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29
30

import torch
from torch import nn
from transformers import LlamaConfig

zhuwenwen's avatar
zhuwenwen committed
31
import os
gaoqiong's avatar
gaoqiong committed
32
import re
33
import vllm.envs as envs
34
from vllm.attention import Attention
zhuwenwen's avatar
zhuwenwen committed
35

36
from vllm.compilation.decorators import support_torch_compile
37
from vllm.config import CacheConfig, VllmConfig
38
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import SiluAndMul
40
from vllm.model_executor.layers.layernorm import RMSNorm
41
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
42
43
                                               QKVParallelLinear,
                                               RowParallelLinear)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
47
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import (
51
    default_weight_loader, maybe_remap_kv_scale_name)
52
from vllm.model_executor.sampling_metadata import SamplingMetadata
53
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
54
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
from .interfaces import SupportsLoRA, SupportsPP
57
58
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
59
60
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
61

gaoqiong's avatar
gaoqiong committed
62
from vllm import _custom_ops as ops
63
64
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
65
66

class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections live in a single merged column-parallel
    layer with two outputs of ``intermediate_size`` each; the down
    projection is a row-parallel layer back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ) -> None:
        """Create the two projections and the activation.

        Args:
            hidden_size: Input width of ``gate_up_proj`` and output width
                of ``down_proj``.
            intermediate_size: Width of each of the two merged outputs and
                input width of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded unchanged to both linear layers.
            bias: Forwarded unchanged to both linear layers.
            prefix: Parameter-name prefix; ``.gate_up_proj`` and
                ``.down_proj`` are appended for the sub-layers.
            reduce_results: Forwarded to ``down_proj`` only.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Options that both projections receive verbatim.
        shared_kwargs = dict(bias=bias, quant_config=quant_config)
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )
        # The activation is validated only after both layers exist.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Run ``x`` through gate/up projection, activation, down projection.

        The second element returned by each linear layer is discarded.
        """
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention block with fused QKV projection and rotary embedding.

    Query heads are divided evenly across tensor-parallel ranks. KV heads
    are divided across ranks when there are at least as many KV heads as
    ranks, and replicated (one per rank) otherwise.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        """Build the projections, rotary embedding and attention backend.

        Args:
            config: Model config. Read for the optional ``head_dim``,
                ``partial_rotary_factor`` and ``interleaved_sliding_window``
                attributes, and for ``model_type`` when the quantization
                method is ``"gguf"``.
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of KV heads (before TP split).
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed through to ``get_rope``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            quant_config: Quantization config forwarded to the linear
                layers and the attention backend.
            bias: Whether ``qkv_proj`` has a bias.
            bias_o_proj: Whether ``o_proj`` has a bias.
            cache_config: Forwarded to the attention backend.
            prefix: Parameter-name prefix; also parsed for the layer index.

        Raises:
            ValueError: If ``config.interleaved_sliding_window`` exists but
                is neither an ``int`` nor a ``list``.
        """
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        # The attribute may be present but set to None, in which case a
        # plain getattr default would not kick in, so test for None
        # explicitly before falling back to the derived value.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        # Phi models introduced a partial_rotary_factor parameter in the config
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NeoX-style rotary layout is used except for GGUF-quantized
        # checkpoints whose model_type is "llama".
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                # One window size shared by every layer.
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                # The list is cycled over the layers by layer index.
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so readers never hit an
        # AttributeError on unquantized models.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            positions: Position ids handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is passed alongside the hidden states and is
    combined inside the norm calls rather than added explicitly here.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Assemble attention, MLP and the two norms from ``config``.

        Args:
            config: Model config supplying sizes, RoPE settings and the
                optional bias flags read below.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention block and the MLP.
            prefix: Parameter-name prefix; ``.self_attn`` and ``.mlp`` are
                appended for the sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        # RoPE settings, with the same fallbacks when the config lacks them.
        theta = getattr(config, "rope_theta", 10000)
        scaling_cfg = getattr(config, "rope_scaling", None)
        original_max_pos = getattr(config, "original_max_position_embeddings",
                                   None)
        if scaling_cfg is not None and original_max_pos:
            # Written into the config's own dict (in place).
            scaling_cfg["original_max_position_embeddings"] = original_max_pos
        max_pos = getattr(config, "max_position_embeddings", 8192)

        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        o_proj_bias = (getattr(config, "attention_bias", False)
                       or getattr(config, "bias", False))
        # support internlm/internlm3-8b with qkv_bias: when present it
        # overrides the QKV bias only; o_proj keeps the value above.
        qkv_bias = (config.qkv_bias
                    if hasattr(config, 'qkv_bias') else o_proj_bias)

        total_kv_heads = getattr(config, "num_key_value_heads",
                                 config.num_attention_heads)

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=total_kv_heads,
            rope_theta=theta,
            rope_scaling=scaling_cfg,
            max_position_embeddings=max_pos,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position ids forwarded to the attention block.
            hidden_states: Layer input.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet; in that case the layer input
                itself becomes the residual.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states)

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)
        return mlp_out, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
303
304


305
@support_torch_compile
class LlamaModel(nn.Module):
    """Token embedding, decoder stack and final norm, pipeline-parallel aware.

    Ranks that do not own the embedding or the final norm hold a
    ``PPMissingLayer`` placeholder instead. ``load_weights`` additionally
    applies backend-specific weight re-layouts that are selected by
    environment variables read once in ``__init__``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        """Build the sub-modules owned by this pipeline rank.

        Args:
            vllm_config: Source of the HF model config and the cache,
                quantization and LoRA configs.
            prefix: Parameter-name prefix; ``.layers`` is appended for the
                decoder stack.
            layer_type: Decoder-layer class to instantiate per layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        # Extra vocabulary rows reserved for LoRA adapters, if any.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding is built on the first rank, and also on the last
        # rank when word embeddings are tied.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.tritonsingleton = W8a8GetCacheJSON()
        # Backend switches, sampled once at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    # NOTE(review): the class is wrapped by ``support_torch_compile``; the
    # parameter annotations below are kept exactly as they were in case the
    # decorator inspects them -- confirm before changing the signature.
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. On other ranks the
        hidden states and residual are taken from ``intermediate_tensors``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` on every rank but the last; the normalized hidden
            states on the last rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, then apply backend-specific re-layouts.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # NOTE: a ``continue`` inside this inner loop advances the inner
            # loop only; if no mapping ends in ``break`` the ``else`` branch
            # runs with whatever ``name`` has become by then.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            self._repack_awq_weights(params_dict, loaded_params)

        # When Triton serves the int8 GEMMs the weights must be left as-is;
        # the helper below handles that per ``w8a8_strategy``.
        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)

        return loaded_params

    def _apply_llama_nn_layout(self, params_dict: Dict[str, nn.Parameter],
                               loaded_params: Set[str]) -> None:
        """Re-lay-out loaded unquantized linear weights via ``trans_w16_gemm``.

        Applies to the attention and MLP projection weights, and to an
        ``lm_head.weight`` whose second dimension is at least 4096. Sets the
        ``LM_NN`` environment variable to ``'1'`` if such an ``lm_head``
        weight was processed and to ``'0'`` otherwise.
        """
        projection_keys = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Escape the dots so they match literally; compile once for the loop.
        projection_pattern = re.compile("|".join(
            re.escape(key) for key in projection_keys))

        lm_head_converted = False
        for layername in loaded_params:
            weight = params_dict[layername]
            is_large_lm_head = ("lm_head.weight" in layername
                                and weight.shape[1] >= 4096)
            if not (is_large_lm_head
                    or projection_pattern.search(layername)):
                continue
            lm_head_converted = lm_head_converted or is_large_lm_head

            # NOTE(review): ``ops.trans_w16_gemm`` is a custom op; it
            # presumably writes the re-laid-out copy of its second argument
            # into the first -- confirm against the op's definition.
            converted = torch.zeros_like(weight.data)
            rows, cols = converted.shape[0], converted.shape[1]
            ops.trans_w16_gemm(converted, weight.data, rows, cols)
            weight.data.copy_(converted)
            weight.data = weight.data.reshape(cols, -1)

        # Decide the flag once, after the loop: ``loaded_params`` is an
        # unordered set, so updating the flag per iteration would make its
        # final value depend on which name happens to be visited last.
        os.environ['LM_NN'] = '1' if lm_head_converted else '0'

    def _repack_awq_weights(self, params_dict: Dict[str, nn.Parameter],
                            loaded_params: Set[str]) -> None:
        """Repack loaded AWQ ``qweight`` tensors for the non-Triton kernels.

        For every loaded attention/MLP ``qweight`` the sibling ``qzeros``,
        ``scales`` and ``zeros_and_scales`` parameters are looked up by
        name; ``qweight`` and ``zeros_and_scales`` are overwritten with the
        converted data and reshaped so that their first dimension is
        ``scales.shape[1]``. Optional zero padding is appended when
        ``AWQ_PAD`` is enabled and the original first dimension of
        ``qweight`` is a multiple of 4096.
        """
        qweight_pattern = re.compile("|".join(
            re.escape(key) for key in (
                "self_attn.qkv_proj.qweight",
                "self_attn.o_proj.qweight",
                "mlp.gate_up_proj.qweight",
                "mlp.down_proj.qweight",
            )))
        group_size = int(self.quant_config.group_size)
        # Number of zero columns appended to ``zeros_and_scales`` when padding.
        pad_groups = 2

        for layername in loaded_params:
            if not qweight_pattern.search(layername):
                continue
            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            converted_qweight, converted_sz = ops.convert_s4(
                qweight, qzeros, scales, group_size)
            permuted_sz = ops.sz_permute(converted_sz).reshape(-1, dim_n)

            zeros_and_scales.data.copy_(permuted_sz)
            qweight.data.copy_(converted_qweight)

            # Reshape: [k / group_size, n] -> [n, k / group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # Reshape: [k, n / 8] -> [n, k / 8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate the padding on the same device as the tensor it
                # is concatenated with, rather than the current CUDA device.
                sz_pad = torch.zeros(dim_n,
                                     pad_groups,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, sz_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          group_size // 4,
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self,
                              params_dict: Dict[str, nn.Parameter]) -> None:
        """Prepare compressed-tensors W8A8 weights for the selected backend.

        With ``w8a8_strategy != 1`` every matching weight is overwritten in
        place with its transposed data (the original comment says rocBLAS
        and CUTLASS need this while Triton does not). With
        ``w8a8_strategy == 1`` the weights are left untouched; instead the
        cached Triton configs for the first weight of each projection kind
        are collected, stored on ``self.tritonsingleton`` and each one is
        run once as a warm-up.
        """
        weight_pattern = re.compile("|".join(
            re.escape(key) for key in (
                "self_attn.qkv_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
            )))
        use_triton = self.w8a8_strategy == 1
        triton_configs: Dict[str, Any] = {}
        seen_projection_keys: Set[str] = set()

        for layername, weight in params_dict.items():
            match = weight_pattern.search(layername)
            if match is None or "scale" in layername:
                continue
            n = weight.shape[0]

            if not use_triton:
                # Same storage, transposed contents viewed as [n, -1].
                weight.data.copy_(weight.T.contiguous().reshape(n, -1))
            elif match.group(0) not in seen_projection_keys:
                # Record the (n, k) of each projection kind once and fetch
                # its cached Triton configs.
                seen_projection_keys.add(match.group(0))
                k = weight.shape[1]
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(
                    json_file, n, k)
                if configs_dict:
                    triton_configs.update(configs_dict)

        if use_triton:
            self.tritonsingleton.triton_json_dict.append(triton_configs)
            # Warm up once for every config that was found. Keys start
            # with "<m>_<n>_<k>".
            for key, best_config in triton_configs.items():
                m, n, k = (int(part) for part in key.split('_')[:3])
                ops.triton_int8_gemm_helper(
                    m=m,
                    n=n,
                    k=k,
                    per_token_act_quant=True,
                    per_out_channel_weight_quant=True,
                    use_bias=False,
                    best_config=best_config)
598

Woosuk Kwon's avatar
Woosuk Kwon committed
599

600
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Llama decoder stack plus language-model head, sampler and weight loader.

    The decoder itself is built by ``_init_model``. The LM head and the
    logits processor are only materialised on the last pipeline-parallel
    rank; every other rank gets a ``PPMissingLayer`` placeholder as
    ``lm_head``.
    """

    # Fused module name -> names of the per-projection checkpoint modules
    # that are packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    # Keys are Mistral-format name components, values are the names used
    # by this module tree (consumed by ``maybe_remap_mistral``).
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        """Build the decoder, the LM head (last PP rank only) and the sampler.

        Args:
            vllm_config: Supplies ``model_config.hf_config``,
                ``quant_config`` and ``lora_config``.
            prefix: Parent module path; ``"model"`` and ``"lm_head"`` are
                appended to it through ``maybe_prefix``.
            layer_type: Decoder layer class forwarded to ``_init_model``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"),
                                      layer_type=layer_type)

        if get_pp_group().is_last_rank:
            # LoRA may extend the vocabulary beyond the base model's size.
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            # ``logit_scale`` is optional on the HF config; default is 1.0.
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            # Non-final pipeline ranks do not own an LM head. Note that
            # ``logits_processor`` is not created on these ranks either.
            self.lm_head = PPMissingLayer()

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        """Create the underlying ``LlamaModel``.

        Kept as a separate method so the decoder construction is isolated
        from the rest of ``__init__``.
        """
        return LlamaModel(vllm_config=vllm_config,
                          prefix=prefix,
                          layer_type=layer_type)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup for ``input_ids`` to ``self.model``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder and return its output unchanged.

        All four arguments are passed positionally to ``self.model``; no
        logits are computed here (see ``compute_logits``).
        """
        model_output = self.model(input_ids, positions, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply ``self.logits_processor`` with ``self.lm_head``.

        ``logits_processor`` is only set in ``__init__`` on the last
        pipeline-parallel rank.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Run ``self.sampler`` on ``logits`` and return its result."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. When ``config.tie_word_embeddings`` is
        set, entries under the ``lm_head.`` prefix are skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            as a set of parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight entry using ``mistral_mapping``.

        ``wq`` / ``wk`` ``weight`` tensors are additionally re-laid-out by
        ``permute``. Names with no component in ``mistral_mapping`` are
        returned unchanged.

        Args:
            name: Dotted parameter name from the checkpoint.
            loaded_weight: Tensor stored under ``name``.

        Returns:
            The (possibly renamed) name and the (possibly permuted) tensor.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Views ``w`` as (n_heads, head_dim // 2, 2, hidden_size), swaps
            # the two middle axes and flattens back to (attn_in, attn_out).
            # NOTE(review): presumably converts the checkpoint's rotary
            # layout to the one this model's rotary embedding expects --
            # confirm against the checkpoint format.
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        last_index = len(modules) - 1
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i < last_index else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") take
            # precedence over single-component keys.
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                # The ``not in name`` guard leaves names that already
                # contain the target untouched (e.g. "model.layers...").
                name = name.replace(item, mapping[item])

        return name, loaded_weight