# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
23
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
27

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
28
import os
gaoqiong's avatar
gaoqiong committed
29
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
30

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
36
from vllm.model_executor.layers.activation import SiluAndMul
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
44
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    get_compressed_tensors_cache_scale)
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
46
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import (
50
    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.platforms import current_platform
53
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
54
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
55

56
from .interfaces import SupportsLoRA, SupportsPP
57
58
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
59
60
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
61

gaoqiong's avatar
gaoqiong committed
62
from vllm import _custom_ops as ops
63
64
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
65
66

class LlamaMLP(nn.Module):
    """Feed-forward sub-block of a LLaMA decoder layer.

    ``gate_up_proj`` produces the gate and up projections with a single
    column-parallel matmul (two outputs of ``intermediate_size`` each),
    ``SiluAndMul`` is applied to that merged output, and the row-parallel
    ``down_proj`` maps the ``intermediate_size`` activation back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections share one fused weight.
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        unsupported_activation = hidden_act != "silu"
        if unsupported_activation:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Parallel linear layers return (output, bias); only the output is used.
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        projected, _ = self.down_proj(activated)
        return projected


class LlamaAttention(nn.Module):
    """Self-attention sub-block of a LLaMA decoder layer.

    Q, K and V come from one fused ``QKVParallelLinear``. Rotary position
    embeddings are applied to Q and K, attention is computed by the vLLM
    ``Attention`` layer, and ``o_proj`` maps the heads back to
    ``hidden_size``. Query heads are split evenly across the tensor-parallel
    group. KV heads are split when there are at least as many as TP ranks;
    otherwise each rank keeps one (replicated) KV head.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Per-rank widths of the Q and K/V slices of the fused qkv output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF-quantized models use the non-NeoX rotary style.
        is_neox_style = True
        if quant_config is not None and quant_config.get_name() == "gguf":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        sliding_window = self._resolve_sliding_window(config, layer_idx)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Remember the quantization scheme (if any).
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

    @staticmethod
    def _resolve_sliding_window(config: LlamaConfig,
                                layer_idx: int) -> Optional[int]:
        """Return the sliding-window size for ``layer_idx``.

        ``config.interleaved_sliding_window`` may be:
        * missing or ``None``: no sliding window (returns ``None``);
        * an ``int``: the same window for every layer;
        * a ``list``: per-layer windows, cycled by ``layer_idx``.

        Raises:
            ValueError: for any other type.
        """
        interleaved_sliding_window = getattr(config,
                                             "interleaved_sliding_window",
                                             None)
        # An explicit ``null`` in the config means "disabled", the same as
        # leaving the key out; it used to hit the ValueError below.
        if interleaved_sliding_window is None:
            return None
        if isinstance(interleaved_sliding_window, int):
            return interleaved_sliding_window
        if isinstance(interleaved_sliding_window, list):
            sw_idx = layer_idx % len(interleaved_sliding_window)
            return interleaved_sliding_window[sw_idx]
        raise ValueError(
            f"{type(interleaved_sliding_window)} is not supported.")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention on ``hidden_states`` and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dim into Q, K and V.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA block: RMSNorm -> attention -> RMSNorm -> MLP.

    The residual stream is carried separately from the activations. When a
    residual is passed in, each RMSNorm is called with
    ``(hidden_states, residual)`` and returns ``(normed, new_residual)``.
    ``forward`` returns the MLP output together with that residual.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            orig_max_positions = getattr(config,
                                         "original_max_position_embeddings",
                                         None)
            if orig_max_positions:
                # Pass the pre-extension context length to the RoPE config
                # (this updates the config's rope_scaling dict in place).
                rope_scaling["original_max_position_embeddings"] = (
                    orig_max_positions)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # abacusai/Smaug-72B-v0.1 uses ``attention_bias``;
        # internlm/internlm-7b uses ``bias``.
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        kv_head_count = getattr(config, "num_key_value_heads",
                                config.num_attention_heads)

        # Submodules are registered in this order on purpose: it fixes the
        # order of named_parameters(), which weight post-processing relies on.
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_head_count,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # First layer: the raw input becomes the residual stream.
        if residual is None:
            residual = hidden_states
            attn_input = self.input_layernorm(hidden_states)
        else:
            attn_input, residual = self.input_layernorm(hidden_states,
                                                        residual)

        attn_out = self.self_attn(positions=positions,
                                  hidden_states=attn_input,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)

        mlp_input, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(mlp_input), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
298
299


300
@support_torch_compile
class LlamaModel(nn.Module):
    """Embedding, stack of LLaMA decoder layers, and final RMSNorm.

    Supports pipeline parallelism:
    * Ranks that do not own the embedding or the final norm hold
      ``PPMissingLayer`` placeholders for them.
    * Only layers ``[start_layer, end_layer)`` run on this rank.

    After checkpoint loading, some weights are re-laid-out in place for
    vendor GEMM kernels. This is driven by the ``LLAMA_NN``, ``AWQ_PAD``
    and ``W8A8_SUPPORT_METHODS`` environment variables, which are read at
    construction time.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.padding_idx = config.pad_token_id
        # LoRA may extend the vocabulary with extra tokens per adapter.
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        # The embedding is needed on the first PP rank, and also on the last
        # rank when the LM head shares (ties) its weights.
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

        # Vendor kernel switches, read once at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
        # Triton W8A8 config cache. _prepare_w8a8_weights creates it lazily;
        # previously it was read there without ever being assigned.
        self.tritonsingleton = None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final-normed hidden states on the last PP rank.
        Otherwise returns the ``hidden_states``/``residual`` pair as
        ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches only holds this rank's layers, hence the offset.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, then apply vendor-specific re-layouts.

        Separate HF ``q/k/v_proj`` and ``gate/up_proj`` tensors are loaded
        into the fused ``qkv_proj`` / ``gate_up_proj`` parameters.

        Returns:
            The set of parameter names that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if scale_name := get_compressed_tensors_cache_scale(name):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._convert_llama_nn_weights(params_dict)
        if self.quant_method == "awq":
            self._convert_awq_weights(params_dict)
        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)
        return loaded_params

    @staticmethod
    def _keyword_pattern(keywords: List[str]) -> "re.Pattern":
        """Compile a regex matching any of ``keywords`` literally."""
        return re.compile("|".join(re.escape(word) for word in keywords))

    def _convert_llama_nn_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Re-lay-out unquantized projection weights with ``trans_w16_gemm``.

        Each matched weight keeps its storage. The converted data is copied
        back in, and the tensor is then reshaped to ``(shape[1], -1)``.
        ``LM_NN`` is set to ``'1'`` only if an ``lm_head.weight`` with
        ``shape[1] >= 4096`` is present. This is decided once, rather than
        being overwritten by every parameter in the loop.
        """
        keywords = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        use_lm_nn = any(
            "lm_head.weight" in name and weight.shape[1] >= 4096
            for name, weight in params_dict.items())
        if use_lm_nn:
            keywords.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if use_lm_nn else '0'

        pattern = self._keyword_pattern(keywords)
        for name, weight in params_dict.items():
            if not pattern.search(name):
                continue
            converted = torch.zeros_like(weight.data)
            ori_shape = converted.shape
            ops.trans_w16_gemm(converted, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(converted)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _convert_awq_weights(self,
                             params_dict: Dict[str, nn.Parameter]) -> None:
        """Convert AWQ projection weights to the layout used by the vendor kernels.

        For every ``*.qweight`` of the four projections:
        * ``ops.convert_s4`` repacks ``qweight`` and derives combined
          zeros/scales, which are stored in ``zeros_and_scales``;
        * both tensors are reshaped to lead with ``n``;
        * with ``AWQ_PAD=1`` and ``k`` a multiple of 4096, both are
          zero-padded along dim 1.
        """
        pattern = self._keyword_pattern([
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        ])
        group_size = self.quant_config.group_size
        pad_group = 2
        for name, qweight in params_dict.items():
            if not pattern.search(name):
                continue
            qzeros = params_dict[name.replace("qweight", "qzeros")]
            scales = params_dict[name.replace("qweight", "scales")]
            zeros_and_scales = params_dict[name.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            _qw, _sz = ops.convert_s4(qweight, qzeros, scales,
                                      int(group_size))
            sz = ops.sz_permute(_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(_qw)

            # zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            # qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate the padding directly on the weights' device
                # instead of building it on the CPU and copying it over.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     int(group_size // 4),
                                     dtype=torch.int32,
                                     device=qweight.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self,
                              params_dict: Dict[str, nn.Parameter]) -> None:
        """Prepare compressed-tensors W8A8 weights for the selected GEMM backend.

        * ``W8A8_SUPPORT_METHODS != 1`` (rocBLAS / CUTLASS): each projection
          weight's data is replaced by its transpose, stored in a buffer of
          the original ``[n, k]`` shape. Triton does not need this.
        * ``W8A8_SUPPORT_METHODS == 1`` (Triton): tuned configs are looked up
          once per distinct ``(n, k)`` weight shape, registered with the
          Triton cache, and every config found is warmed up.
        """
        pattern = self._keyword_pattern([
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ])

        if self.w8a8_strategy != 1:
            for name, weight in params_dict.items():
                if not pattern.search(name) or "scale" in name:
                    continue
                n = weight.shape[0]
                weight.data.copy_(weight.data.T.contiguous().reshape(n, -1))
            return

        tritonsingleton = getattr(self, "tritonsingleton", None)
        if tritonsingleton is None:
            tritonsingleton = W8a8GetCacheJSON()
            self.tritonsingleton = tritonsingleton

        # Dedupe on an ordered (n, k) tuple. The old code stored {n, k} sets
        # (unordered, and collapsing when n == k) and capped the list at four
        # entries without checking for duplicates.
        seen_shapes: Set[Tuple[int, int]] = set()
        all_json: Dict[str, Any] = {}
        for name, weight in params_dict.items():
            if not pattern.search(name) or "scale" in name:
                continue
            shape = (weight.shape[0], weight.shape[1])
            if shape in seen_shapes:
                continue
            seen_shapes.add(shape)
            n, k = shape
            json_file = tritonsingleton.get_w8a8json_name(n, k)
            configs_dict = tritonsingleton.get_triton_cache(json_file, n, k)
            if configs_dict:
                all_json.update(configs_dict)

        tritonsingleton.triton_json_dict.append(all_json)
        # Warm up every config found; keys are parsed as "m_n_k".
        for key, best_config in all_json.items():
            m, n, k = (int(dim) for dim in key.split('_')[:3])
            ops.triton_int8_gemm_helper(m=m,
                                        n=n,
                                        k=k,
                                        per_token_act_quant=True,
                                        per_out_channel_weight_quant=True,
                                        use_bias=False,
                                        best_config=best_config)

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Load per-layer KV-cache scaling factors into this rank's layers.

        Layers owned by another pipeline-parallel rank (``nn.Identity``
        placeholders) are skipped. Previously they left ``layer_self_attn``
        unbound or pointing at the previous layer.

        The factor is written to the ``Attention`` layer's
        ``_k_scale``/``_v_scale``, or to ``_kv_scale`` if only that exists.
        The old check looked for ``kv_scale`` on ``LlamaAttention``, which
        never has it, so the method always raised.

        Raises:
            RuntimeError: if the attention layer has none of these attributes.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
                quantization_param_path, tp_rank, tp_size,
                self.config.num_hidden_layers,
                self.config.__class__.model_type):
            layer = self.layers[layer_idx]
            if isinstance(layer, nn.Identity):
                continue
            attn = layer.self_attn.attn

            if current_platform.is_rocm():
                # The scaling factor convention we are assuming is
                # quantized_value * scaling_factor ~= true_value
                # which is consistent with the practice of setting
                # scaling_factor = tensor_amax / FPtype_max
                scaling_factor *= 2
            if hasattr(attn, "_k_scale") and hasattr(attn, "_v_scale"):
                attn._k_scale = scaling_factor
                attn._v_scale = scaling_factor
            elif hasattr(attn, "_kv_scale"):
                attn._kv_scale = scaling_factor
            else:
                raise RuntimeError("Self attention has no KV cache scaling "
                                   "factor attribute!")
607

Woosuk Kwon's avatar
Woosuk Kwon committed
608

609
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
610
    # Fused projection modules and, for each one, the checkpoint projections
    # whose weights are stacked into it (listed in stacking order).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA-specific attributes: the modules that may receive adapters, the
    # embedding modules with the names LoRA uses for them, and the modules
    # whose vocabulary dimension is (presumably) padded for LoRA.
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # BitsAndBytes-specific attributes:
    # checkpoint shard name -> (fused parameter name, shard index).
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints. Each key is a Mistral-format
    # name component; each value is its counterpart in this module's names.
    mistral_mapping = {
        # Decoder layers and attention.
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        # Feed-forward block.
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        # Embeddings, output head and final norm.
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }
655

656
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Assemble the causal-LM wrapper around the LLaMA decoder.

        The decoder is built through ``self._init_model``. Only the last
        pipeline-parallel rank gets a real ``lm_head`` and a
        ``logits_processor``; other ranks get a ``PPMissingLayer`` head.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``,
                ``quant_config`` and ``lora_config`` are read from it.
            prefix: Name prefix for this module's parameters.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = hf_config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
        # NOTE(review): the name suggests a singleton holding cached Triton
        # W8A8 GEMM configurations — confirm.
        self.tritonsingleton = W8a8GetCacheJSON()

        if not get_pp_group().is_last_rank:
            # Non-final pipeline stages never compute logits.
            self.lm_head = PPMissingLayer()
        else:
            lora_extra_vocab = (lora_config.lora_extra_vocab_size
                                if lora_config else 0)
            self.unpadded_vocab_size = hf_config.vocab_size + lora_extra_vocab
            # LoRA kernels need a bigger vocab padding for compatibility.
            vocab_padding = (lora_config.lora_vocab_padding_size
                             if lora_config else DEFAULT_VOCAB_PADDING_SIZE)
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                # Reuse the input token embedding as the output head.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0))

        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
zhuwenwen's avatar
zhuwenwen committed
701
        
Woosuk Kwon's avatar
Woosuk Kwon committed
702

703
704
705
    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Construct and return the ``LlamaModel`` decoder backbone.

        Kept as a separate method, presumably so subclasses can substitute a
        different backbone.
        """
        backbone = LlamaModel(vllm_config=vllm_config, prefix=prefix)
        return backbone

706
707
708
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-id embedding lookup to the wrapped model."""
        embed = self.model.get_input_embeddings
        return embed(input_ids)

Woosuk Kwon's avatar
Woosuk Kwon committed
709
710
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped decoder and return whatever it produces.

        All arguments are forwarded positionally, in order, to ``self.model``.
        The return value is either hidden states or ``IntermediateTensors``
        (the latter presumably on non-final pipeline stages).
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)
722

723
724
725
726
727
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits with ``self.logits_processor``.

        ``self.lm_head`` is passed to the processor along with the hidden
        states and sampling metadata; the processor's result is returned.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

732
733
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using ``self.sampler``."""
        return self.sampler(logits, sampling_metadata)

zhuwenwen's avatar
zhuwenwen committed
737

738
739
    def load_weights(
            self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model via ``AutoWeightsLoader``.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``, so Mistral-format and HF-format checkpoints
        share this path. When word embeddings are tied, the loader is told
        to skip the ``lm_head.`` prefix.

        Returns:
            The set reported by ``AutoWeightsLoader.load_weights``
            (presumably the names of the parameters that were loaded).
        """
        skip_prefixes = (["lm_head."]
                         if self.config.tie_word_embeddings else None)
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        remapped = (self.maybe_remap_mistral(ckpt_name, tensor)
                    for ckpt_name, tensor in weights)
        return loader.load_weights(remapped)
748

zhuwenwen's avatar
zhuwenwen committed
749

750
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Delegate KV cache scale loading to the wrapped decoder model."""
        load_scales = self.model.load_kv_cache_scales
        load_scales(quantization_param_path)
752
753
754
755

    # Remap weight names (and reorder rotary q/k weight rows) from the Mistral
    # consolidated checkpoint format, as used by Mistral and Llama <= 2.
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate one Mistral-format checkpoint entry to this model's names.

        Args:
            name: Dotted checkpoint parameter name.
            loaded_weight: The tensor stored under ``name``.

        Returns:
            ``(new_name, weight)``. For entries with a ``wk`` or ``wq``
            component the weight rows are reordered per head; all other
            weights are returned as-is. Each name component found in
            ``self.mistral_mapping`` is substituted unless its replacement
            already occurs in the name, so names already in this model's
            format (e.g. ``model.layers.0.self_attn.q_proj.weight``) pass
            through unchanged.
        """

        def _deinterleave(weight: torch.Tensor,
                          num_heads: int) -> torch.Tensor:
            # Within each head, row 2*i moves to position i and row 2*i + 1
            # to position head_dim // 2 + i (presumably to match the
            # rotate-half rotary layout used by the attention layer).
            rows = self.config.head_dim * num_heads
            cols = self.config.hidden_size
            half_head = rows // num_heads // 2
            grouped = weight.view(num_heads, half_head, 2, cols)
            return grouped.transpose(1, 2).reshape(rows, cols)

        components = name.split(".")

        # Only the rotary-embedded projections (query and key) are reordered.
        if "wk" in components:
            loaded_weight = _deinterleave(loaded_weight,
                                          self.config.num_key_value_heads)
        elif "wq" in components:
            loaded_weight = _deinterleave(loaded_weight,
                                          self.config.num_attention_heads)

        renames = self.mistral_mapping
        for component in components:
            target = renames.get(component)
            # NOTE: str.replace rewrites every occurrence of ``component`` in
            # the (possibly already renamed) name, not only this component.
            if target is not None and target not in name:
                name = name.replace(component, target)

        return name, loaded_weight