llama.py 30.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
5
# Copyright 2023 The vLLM team.
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
24
"""Inference-only LLaMA model compatible with HuggingFace weights."""
25
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
30
import os
gaoqiong's avatar
gaoqiong committed
31
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
32

33
from vllm.attention import Attention, AttentionMetadata
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
45
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import (
49
    default_weight_loader, maybe_remap_kv_scale_name)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
52
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

gaoqiong's avatar
gaoqiong committed
60
from vllm import _custom_ops as ops
61
62
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
63
64

class LlamaMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, SiLU-and-mul, down proj.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        bias: Whether the projections carry a bias term.
        prefix: Parameter-name prefix for this module.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        gate_up_prefix = f"{prefix}.gate_up_proj"
        down_prefix = f"{prefix}.down_proj"

        # Gate and up projections share one merged column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=gate_up_prefix,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=down_prefix,
        )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the activation, then the down projection."""
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        down, _ = self.down_proj(activated)
        return down


class LlamaAttention(nn.Module):
    """Self-attention with rotary position embeddings and tensor parallelism.

    Query heads are split evenly across the tensor-parallel group. KV heads
    are split when there are at least as many as TP ranks, and replicated
    otherwise.

    Args:
        config: HF config; reads ``head_dim``, ``model_type`` and
            ``interleaved_sliding_window`` when present.
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads (before TP sharding).
        num_kv_heads: Total number of key/value heads (before TP sharding).
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position for the rotary embedding.
        quant_config: Optional quantization config.
        bias: Bias flag for the fused QKV projection.
        bias_o_proj: Bias flag for the output projection.
        cache_config: KV-cache config passed to ``Attention``.
        prefix: Parameter-name prefix; also used to derive the layer index.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads: always evenly divisible across TP ranks.
        self.total_num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_heads = num_heads // tp_size

        # KV heads: shard when there are enough of them, otherwise replicate.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < tp_size:
            assert tp_size % num_kv_heads == 0
        else:
            assert num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        fallback_head_dim = hidden_size // num_heads
        self.head_dim = getattr(config, "head_dim", fallback_head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF-quantized "llama" models use the non-NeoX rotary style.
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        use_gptj_style = bool(is_gguf and config.model_type == "llama")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=not use_gptj_style,
        )

        # Optional per-layer sliding window: a single int applies to every
        # layer; a list is indexed cyclically by layer index.
        sliding_window = None
        if hasattr(config, "interleaved_sliding_window"):
            window_spec = config.interleaved_sliding_window
            if isinstance(window_spec, int):
                sliding_window = window_spec
            elif isinstance(window_spec, list):
                sliding_window = window_spec[layer_idx % len(window_spec)]
            else:
                raise ValueError(
                    f"{type(window_spec)} is not supported.")

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Name of the quantization method, or None when unquantized.
        # NOTE: `self.quant_config` is only set on this module when a
        # quantization config is given.
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        # NOTE: an FA_PAD variant that trimmed 32 trailing columns from qkv
        # for unquantized models is currently disabled.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class LlamaDecoderLayer(nn.Module):
    """One pre-norm transformer block: norm -> attention -> norm -> MLP.

    The RMSNorm modules are called with the running residual and return the
    normalized states together with the updated residual. ``forward``
    therefore returns ``(hidden_states, residual)`` instead of a summed
    tensor.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            original_max = getattr(config, "original_max_position_embeddings",
                                   None)
            if original_max:
                # NOTE: this writes into the config's own rope_scaling dict.
                rope_scaling["original_max_position_embeddings"] = (
                    config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)

        # abacusai/Smaug-72B-v0.1 uses `attention_bias`;
        # internlm/internlm-7b uses `bias`.
        o_proj_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        # internlm/internlm3-8b configures the QKV bias separately; o_proj
        # keeps the generic flag above.
        qkv_bias = config.qkv_bias if hasattr(config,
                                              'qkv_bias') else o_proj_bias

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(mlp_output, residual)``.

        When ``residual`` is None (first layer), the incoming hidden states
        become the residual.
        """
        # Self attention.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(positions=positions,
                                     hidden_states=hidden_states,
                                     kv_cache=kv_cache,
                                     attn_metadata=attn_metadata)

        # Fully connected.
        mlp_input, residual = self.post_attention_layernorm(
            attn_output, residual)
        return self.mlp(mlp_input), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
302
303


304
@support_torch_compile
class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers and final norm.

    Under pipeline parallelism only the layers in ``[start_layer, end_layer)``
    are built on this rank. The embedding and the final norm become
    ``PPMissingLayer`` on ranks that do not need them.

    ``__init__`` reads several environment variables once (``LLAMA_NN``,
    ``GEMM_PAD``, ``FA_PAD``, ``AWQ_PAD``, ``W8A8_SUPPORT_METHODS``).
    ``load_weights`` uses them to rewrite loaded weights in place for the
    custom kernels in ``vllm._custom_ops``.
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Quantization method name (e.g. "awq", "compressed_tensors"), or
        # None when unquantized. It selects the post-load rewrite in
        # load_weights.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Feature flags, read once at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        # W8A8 GEMM backend selector; 1 is the Triton path (see
        # _prepare_w8a8_weights), any other value uses rocBLAS/CUTLASS.
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        The first PP rank starts from ``inputs_embeds`` (if given) or from the
        embeddings of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the final normed
        hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # kv_caches holds only this rank's layers, hence the offset.
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are loaded as shards of
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. After loading,
        weights may be rewritten in place for custom kernels:

        * ``LLAMA_NN=1`` with no quantization: see
          ``_relayout_unquantized_weights``.
        * ``quant_method == "awq"``: see ``_repack_awq_weights``.
        * ``quant_method == "compressed_tensors"``: see
          ``_prepare_w8a8_weights``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_unquantized_weights(params_dict)
        if self.quant_method == "awq":
            self._repack_awq_weights(params_dict)
        # Under the Triton W8A8 backend the weights themselves must not be
        # modified; _prepare_w8a8_weights handles both backends.
        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)
        return loaded_params

    def _relayout_unquantized_weights(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Re-lay out unquantized linear weights via ``ops.trans_w16_gemm``.

        Each matching 2-D weight of shape ``[rows, cols]`` is rewritten by the
        kernel and then viewed as ``[cols, rows]``. Presumably this is a
        transpose; TODO confirm against the kernel.

        Also exports ``LM_NN`` to the environment: ``'1'`` when an
        ``lm_head.weight`` with at least 4096 columns is present (that weight
        is then re-laid out too), ``'0'`` otherwise.
        """
        key_words = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Decide over all parameters at once, so the exported flag does not
        # depend on parameter iteration order.
        has_large_lm_head = any(
            "lm_head.weight" in name and weight.shape[1] >= 4096
            for name, weight in params_dict.items())
        if has_large_lm_head:
            key_words.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if has_large_lm_head else '0'

        for name, weight in params_dict.items():
            if not any(key in name for key in key_words):
                continue
            rows, cols = weight.shape[0], weight.shape[1]
            relaid = torch.zeros_like(weight.data)
            ops.trans_w16_gemm(relaid, weight.data, rows, cols)
            weight.data.copy_(relaid)
            weight.data = weight.data.reshape(cols, -1)

    def _repack_awq_weights(self,
                            params_dict: Dict[str, nn.Parameter]) -> None:
        """Convert AWQ int4 weights into the layout of the custom kernels.

        For each projection's ``qweight``:

        * ``ops.convert_s4`` repacks the weight together with its
          ``qzeros``/``scales``.
        * The permuted zeros/scales are written into ``zeros_and_scales``.
        * Both tensors are reshaped from ``[k, ...]`` to ``[n, ...]``.
        * With ``AWQ_PAD=1`` and K a multiple of 4096, both are zero-padded
          along the last dimension.
        """
        key_words = (
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        )
        group_size = self.quant_config.group_size
        pad_groups = 2

        for name, qweight in params_dict.items():
            if not any(key in name for key in key_words):
                continue
            qzeros = params_dict[name.replace("qweight", "qzeros")]
            scales = params_dict[name.replace("qweight", "scales")]
            zeros_and_scales = params_dict[name.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            packed_qweight, packed_sz = ops.convert_s4(qweight, qzeros, scales,
                                                       int(group_size))
            permuted_sz = ops.sz_permute(packed_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(permuted_sz)
            qweight.data.copy_(packed_qweight)

            # [k/group_size, n] -> [n, k/group_size]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            # [k, n/8] -> [n, k/8]
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate padding on the parameters' own device instead of
                # assuming the current CUDA device.
                sz_pad = torch.zeros(dim_n,
                                     pad_groups,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, sz_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          int(group_size // 4),
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    def _w8a8_triton_cache(self):
        """Return the W8A8 Triton config cache, creating it on first use.

        An instance already attached as ``self.tritonsingleton`` (e.g. by
        external code) is reused. Otherwise a ``W8a8GetCacheJSON`` is
        created and stored there.
        """
        cache = getattr(self, "tritonsingleton", None)
        if cache is None:
            cache = W8a8GetCacheJSON()
            self.tritonsingleton = cache
        return cache

    def _prepare_w8a8_weights(self,
                              params_dict: Dict[str, nn.Parameter]) -> None:
        """Prepare compressed-tensors W8A8 linear weights for the GEMM backend.

        rocBLAS / CUTLASS (``W8A8_SUPPORT_METHODS != 1``): each weight's data
        is transposed in place, keeping its ``[n, k]`` shape.

        Triton (``== 1``): weights are left untouched. Instead, tuned configs
        for the ``(n, k)`` shapes of the first four matching weights are
        looked up, appended to the cache's ``triton_json_dict`` and each is
        warmed up once.
        """
        key_words = (
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        )
        # Scale tensors share these name prefixes but are excluded.
        linear_weights = [
            weight for name, weight in params_dict.items()
            if "scale" not in name and any(key in name for key in key_words)
        ]

        if self.w8a8_strategy != 1:
            # rocBLAS and CUTLASS both need the weight re-laid out; Triton
            # does not.
            for weight in linear_weights:
                n = weight.shape[0]
                weight.data.copy_(weight.data.T.contiguous().reshape(n, -1))
            return

        triton_cache = self._w8a8_triton_cache()
        all_json: Dict[str, Any] = {}
        # Record the (n, k) shapes of the first four matching weights.
        for weight in linear_weights[:4]:
            n, k = weight.shape[0], weight.shape[1]
            json_file = triton_cache.get_w8a8json_name(n, k)
            configs_dict = triton_cache.get_triton_cache(json_file, n, k)
            if configs_dict:
                all_json.update(configs_dict)
        triton_cache.triton_json_dict.append(all_json)

        # Warm up every config that was found; keys start with "<m>_<n>_<k>".
        for key, best_config in all_json.items():
            parts = key.split('_')
            m, n, k = int(parts[0]), int(parts[1]), int(parts[2])
            ops.triton_int8_gemm_helper(m=m,
                                        n=n,
                                        k=k,
                                        per_token_act_quant=True,
                                        per_out_channel_weight_quant=True,
                                        use_bias=False,
                                        best_config=best_config)
594

Woosuk Kwon's avatar
Woosuk Kwon committed
595

596
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
597
    # Fused (packed) module name -> checkpoint sub-modules whose weights are
    # stacked into it.
    packed_modules_mapping = dict(
        qkv_proj=["q_proj", "k_proj", "v_proj"],
        gate_up_proj=["gate_proj", "up_proj"],
    )

    # LoRA specific attributes
    # Module names that LoRA adapters are allowed to target.
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    # Embedding-like modules and the role name each one is registered under
    # (consumed by LoRA code outside this file).
    embedding_modules = dict(
        embed_tokens="input_embeddings",
        lm_head="output_embeddings",
    )
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints.
    # Each key is one dot-separated component of a consolidated-checkpoint
    # parameter name; maybe_remap_mistral substitutes it with the value.
    mistral_mapping = dict(
        layers="model.layers",
        attention="self_attn",
        wq="q_proj",
        wk="k_proj",
        wv="v_proj",
        wo="o_proj",
        attention_norm="input_layernorm",
        feed_forward="mlp",
        w1="gate_proj",
        w2="down_proj",
        w3="up_proj",
        ffn_norm="post_attention_layernorm",
        tok_embeddings="model.embed_tokens",
        output="lm_head",
        norm="model.norm",
    )
632

633
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Assemble the Llama causal-LM: backbone, LM head and sampler.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``,
                ``quant_config`` and ``lora_config`` are read from it.
            prefix: Parameter-name prefix for this module. The backbone and
                the head get the sub-prefixes "model" and "lm_head" via
                ``maybe_prefix``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        lora_cfg = vllm_config.lora_config

        self.config = hf_config
        self.lora_config = lora_cfg

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        # W8A8 Triton GEMM config-cache helper. NOTE(review): not referenced
        # by the other methods of this class that are visible here.
        self.tritonsingleton = W8a8GetCacheJSON()

        if not get_pp_group().is_last_rank:
            # Only the last pipeline stage owns the LM head / logits; other
            # stages keep a placeholder layer.
            self.lm_head = PPMissingLayer()
        else:
            # LoRA can add extra vocabulary rows, and needs a bigger vocab
            # padding for kernel compatibility.
            extra_vocab = lora_cfg.lora_extra_vocab_size if lora_cfg else 0
            vocab_padding = (lora_cfg.lora_vocab_padding_size
                             if lora_cfg else DEFAULT_VOCAB_PADDING_SIZE)
            self.unpadded_vocab_size = hf_config.vocab_size + extra_vocab

            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
                quant_config=quant_cfg,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if hf_config.tie_word_embeddings:
                # Reuse the input embedding weights for the output head.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size,
                hf_config.vocab_size,
                getattr(hf_config, "logit_scale", 1.0),
            )

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
zhuwenwen's avatar
zhuwenwen committed
678
        
Woosuk Kwon's avatar
Woosuk Kwon committed
679

680
681
682
    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder backbone.

        Kept as its own method, presumably so subclasses can substitute a
        different backbone class.
        """
        backbone_cls = LlamaModel
        return backbone_cls(vllm_config=vllm_config, prefix=prefix)

683
684
685
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's embeddings for ``input_ids``."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

Woosuk Kwon's avatar
Woosuk Kwon committed
686
687
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder backbone and return its output unchanged.

        Per the annotation the result is either a hidden-state tensor or
        ``IntermediateTensors`` (presumably the latter on non-final pipeline
        stages -- not established here). Logits are produced separately by
        ``compute_logits``.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)
699

700
701
702
703
704
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into vocabulary logits through ``lm_head``.

        Uses ``self.logits_processor``, which ``__init__`` only creates on
        the last pipeline-parallel rank.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

709
710
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the sampler set in __init__."""
        return self.sampler(logits, sampling_metadata)

zhuwenwen's avatar
zhuwenwen committed
714

715
716
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights after Mistral-style name remapping.

        Every ``(name, tensor)`` pair is lazily routed through
        ``maybe_remap_mistral`` before ``AutoWeightsLoader`` sees it. With
        tied word embeddings the loader is given ``skip_prefixes=["lm_head."]``
        (presumably so the checkpoint's head tensors are ignored).

        Returns:
            The result of ``AutoWeightsLoader.load_weights`` (annotated as the
            set of loaded parameter names).
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head."]
        else:
            skipped = None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        remapped = (self.maybe_remap_mistral(param_name, tensor)
                    for param_name, tensor in weights)
        return loader.load_weights(remapped)
725

zhuwenwen's avatar
zhuwenwen committed
726

727
728
729
    # Remaps parameter names/tensors from the consolidated "mistral" checkpoint
    # format (used by Mistral and Llama <= 2) to this module's naming.
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate one checkpoint entry to HF-style naming.

        1. ``wq`` / ``wk`` weights are permuted (rows regrouped per head)
           to match the rotary-embedding layout; ``wk`` is checked first.
        2. Each dot-separated component of ``name`` that is a key of
           ``self.mistral_mapping`` is replaced via ``str.replace``, unless
           the replacement text already occurs in ``name``. That guard is
           what lets names that are already HF-style (e.g.
           ``model.layers.0...``) pass through unchanged.

        Returns:
            The (possibly renamed) name and the (possibly permuted) tensor.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Split rows into (heads, head_dim // 2, 2), swap the last two row
            # axes, and flatten back to a 2-D (rows, hidden_size) matrix.
            rows = self.config.head_dim * n_heads
            cols = self.config.hidden_size
            half_head = rows // n_heads // 2
            regrouped = w.view(n_heads, half_head, 2, cols).transpose(1, 2)
            return regrouped.reshape(rows, cols)

        parts = name.split(".")

        if "wk" in parts:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in parts:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        table = self.mistral_mapping
        for part in parts:
            target = table.get(part)
            if target is not None and target not in name:
                name = name.replace(part, target)

        return name, loaded_weight