llama.py 32.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
Woosuk Kwon's avatar
Woosuk Kwon committed
5
# Copyright 2023 The vLLM team.
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
24
"""Inference-only LLaMA model compatible with HuggingFace weights."""
25
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
30
import os
gaoqiong's avatar
gaoqiong committed
31
import re
32
import vllm.envs as envs
33
from vllm.attention import Attention, AttentionMetadata
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
45
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import (
49
    default_weight_loader, maybe_remap_kv_scale_name)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
52
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

gaoqiong's avatar
gaoqiong committed
60
from vllm import _custom_ops as ops
61
62
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
63
64

class LlamaMLP(nn.Module):
    """Gated feed-forward block of a LLaMA decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer whose two output shards are combined by ``SiluAndMul``; the result
    is projected back to ``hidden_size`` by a row-parallel linear layer.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two equally sized output shards: [gate, up].
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=fused_output_sizes,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        # Only the SiLU-gated variant is implemented.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, SiLU gating, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        gated = self.act_fn(gate_up)
        projected, _ = self.down_proj(gated)
        return projected


class LlamaAttention(nn.Module):
    """Rotary-embedded self-attention with optional grouped-query KV heads.

    Q, K and V are produced by a single fused ``QKVParallelLinear``. Query
    heads are partitioned across tensor-parallel ranks; KV heads are
    partitioned when there are at least ``tp_size`` of them and replicated
    otherwise.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Phi models introduced a partial_rotary_factor parameter in the config
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # GGUF llama checkpoints are run with the non-neox rotary style.
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        # interleaved_sliding_window may be a single window for every layer
        # or a list cycled by layer index.
        if hasattr(config, "interleaved_sliding_window"):
            interleaved_sliding_window = config.interleaved_sliding_window
            if isinstance(interleaved_sliding_window, int):
                sliding_window = interleaved_sliding_window
            elif isinstance(interleaved_sliding_window, list):
                sw_idx = layer_idx % len(interleaved_sliding_window)
                sliding_window = interleaved_sliding_window[sw_idx]
            else:
                raise ValueError(
                    f"{type(interleaved_sliding_window)} is not supported.")
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes (previously quant_config was only
        # set when quantized, so reading it on unquantized models raised
        # AttributeError).
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to QKV, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One pre-norm LLaMA transformer block: self-attention, then an MLP.

    RMSNorm layers carry the residual stream: when a residual is passed in,
    the norm performs the residual add and returns the updated residual.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        # Note: this writes into the config's own rope_scaling dict.
        if rope_scaling is not None:
            orig_max_positions = getattr(
                config, "original_max_position_embeddings", None)
            if orig_max_positions:
                rope_scaling["original_max_position_embeddings"] = (
                    orig_max_positions)
        max_positions = getattr(config, "max_position_embeddings", 8192)

        # abacusai/Smaug-72B-v0.1 exposes `attention_bias`;
        # internlm/internlm-7b exposes `bias`. o_proj follows this flag.
        o_proj_bias = (getattr(config, "attention_bias", False)
                       or getattr(config, "bias", False))
        # internlm/internlm3-8b overrides only the QKV bias via `qkv_bias`.
        qkv_bias = getattr(config, "qkv_bias", o_proj_bias)

        kv_heads = getattr(config, "num_key_value_heads",
                           config.num_attention_heads)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=qkv_bias,
            bias_o_proj=o_proj_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return (mlp_output, residual)."""
        # Self attention, with the residual fused into the norm when present.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states,
                                  kv_cache=kv_cache,
                                  attn_metadata=attn_metadata)

        # Fully connected.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual
Woosuk Kwon's avatar
Woosuk Kwon committed
305
306


307
@support_torch_compile
class LlamaModel(nn.Module):
    """LLaMA decoder stack: token embedding, decoder layers, final norm.

    Besides standard checkpoint loading, ``load_weights`` post-processes the
    loaded tensors for this fork's custom kernels, driven by environment
    variables read in ``__init__``:

    * ``LLAMA_NN=1`` (unquantized only): re-layout GEMM weights with
      ``ops.trans_w16_gemm``.
    * AWQ without ``VLLM_USE_TRITON_AWQ``: repack ``qweight``/``qzeros``/
      ``scales`` into ``zeros_and_scales`` (``AWQ_PAD=1`` adds padding).
    * compressed-tensors: per ``W8A8_SUPPORT_METHODS``, either pre-transpose
      int8 weights (strategy != 1) or collect and warm up Triton GEMM configs
      (strategy == 1).
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Name of the active quantization method, None when unquantized.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.tritonsingleton = W8a8GetCacheJSON()

        # Feature switches for this fork's custom kernels.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this PP rank's layers.

        Returns the final normed hidden states on the last PP rank, and
        ``IntermediateTensors`` (hidden states + residual) otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, then apply kernel-specific re-layouts.

        Args:
            weights: iterable of (checkpoint_name, tensor) pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_w16_weights(params_dict, loaded_params)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            self._repack_awq_weights(params_dict, loaded_params)

        # When Triton serves W8A8 inference the weights must not be
        # pre-transposed; the helper branches on the strategy.
        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)

        return loaded_params

    @staticmethod
    def _keyword_pattern(key_words: List[str]) -> "re.Pattern":
        """Compile a literal alternation of the given parameter-name parts."""
        return re.compile("|".join(re.escape(word) for word in key_words))

    def _relayout_w16_weights(self, params_dict: Dict[str, nn.Parameter],
                              loaded_params: Set[str]) -> None:
        """Re-layout unquantized GEMM weights for the LLAMA_NN kernels.

        Each matched weight of shape ``(rows, cols)`` is rewritten through
        ``ops.trans_w16_gemm`` and then viewed as ``(cols, -1)``. An
        ``lm_head.weight`` with at least 4096 columns is also re-laid out and
        enables ``LM_NN``. ``LM_NN`` is decided once from all loaded names;
        previously it was overwritten per item of an unordered set, so its
        final value depended on iteration order.
        """
        key_words = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        lm_head_nn = any(
            "lm_head.weight" in name and params_dict[name].shape[1] >= 4096
            for name in loaded_params)
        if lm_head_nn:
            key_words.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if lm_head_nn else '0'

        pattern = self._keyword_pattern(key_words)
        for name in loaded_params:
            if pattern.search(name) is None:
                continue
            weight = params_dict[name]
            transformed = torch.zeros_like(weight.data)
            rows, cols = transformed.shape[0], transformed.shape[1]
            ops.trans_w16_gemm(transformed, weight.data, rows, cols)
            weight.data.copy_(transformed)
            weight.data = weight.data.reshape(cols, -1)

    def _repack_awq_weights(self, params_dict: Dict[str, nn.Parameter],
                            loaded_params: Set[str]) -> None:
        """Repack AWQ int4 tensors for the fork's non-Triton AWQ kernels.

        For every matched ``qweight``, ``ops.convert_s4`` produces a new
        qweight and a combined zeros/scales tensor (permuted by
        ``ops.sz_permute``); both are written back and reshaped to
        ``(n, -1)``. With ``AWQ_PAD=1`` and ``k % 4096 == 0``, zero columns
        are appended (purpose not visible here — presumably a kernel
        alignment requirement; TODO confirm).
        """
        pattern = self._keyword_pattern([
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        ])
        group_size = int(self.quant_config.group_size)
        pad_group = 2
        for name in loaded_params:
            if pattern.search(name) is None:
                continue
            qweight = params_dict[name]
            qzeros = params_dict[name.replace("qweight", "qzeros")]
            scales = params_dict[name.replace("qweight", "scales")]
            zeros_and_scales = params_dict[name.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            new_qweight, new_sz = ops.convert_s4(qweight, qzeros, scales,
                                                 group_size)
            sz = ops.sz_permute(new_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(new_qweight)

            # Per the original author's notes:
            # zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            # qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Allocate padding directly on the parameter's device instead
                # of building it on the CPU and moving it with .cuda().
                sz_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.data.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, sz_pad), dim=1).contiguous()
                qweight_pad = torch.zeros(dim_n,
                                          group_size // 4,
                                          dtype=torch.int32,
                                          device=qweight.data.device)
                qweight.data = torch.cat((qweight.data, qweight_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self, params_dict: Dict[str,
                                                      nn.Parameter]) -> None:
        """Prepare compressed-tensors W8A8 GEMM weights.

        Strategy != 1 (rocBLAS / CUTLASS paths): each matched weight is
        transposed in place while keeping its ``(n, k)`` shape. Strategy == 1
        (Triton): weights are left untouched; cached Triton configs for the
        first occurrence of each projection's ``(n, k)`` are collected,
        registered on the singleton, and each config is warmed up once.
        """
        pattern = self._keyword_pattern([
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ])
        all_json = {}
        matched_key_words = set()

        for name, weight_data in params_dict.items():
            found = pattern.search(name)
            # Names like "...weight_scale" also contain a key word; skip them.
            if found is None or "scale" in name:
                continue
            n = weight_data.shape[0]

            # rocBLAS and CUTLASS both need the weight pre-processed;
            # Triton does not.
            if self.w8a8_strategy != 1:
                transposed = weight_data.T.contiguous().reshape(n, -1)
                weight_data.data.copy_(transposed)
            # Record the (n, k) of each distinct projection kind once.
            elif (len(matched_key_words) < 4
                  and found.group(0) not in matched_key_words):
                matched_key_words.add(found.group(0))
                k = weight_data.shape[1]
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(
                    json_file, n, k)
                if configs_dict:
                    all_json.update(configs_dict)

        if self.w8a8_strategy == 1:
            self.tritonsingleton.triton_json_dict.append(all_json)
            # Warm up every config that was found once. Keys are "m_n_k...".
            for key, value in all_json.items():
                parts = key.split('_')
                m, n, k = int(parts[0]), int(parts[1]), int(parts[2])
                ops.triton_int8_gemm_helper(m=m,
                                            n=n,
                                            k=k,
                                            per_token_act_quant=True,
                                            per_out_channel_weight_quant=True,
                                            use_bias=False,
                                            best_config=value)
606

Woosuk Kwon's avatar
Woosuk Kwon committed
607

608
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
609
    # Maps each fused module name to the per-projection checkpoint names it
    # is built from (presumably consulted by LoRA / quantized loading of
    # packed layers — confirm against callers).
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head"
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings"
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints. Keys are Mistral-format name
    # components (or "a.b" component pairs); values are the names used by
    # this model. Applied by maybe_remap_mistral.
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm"
    }
647

648
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Llama backbone and, on the last pipeline stage, the head.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the LoRA config are read from it.
            prefix: Name prefix for this module's parameters.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            # LoRA can append extra vocabulary entries to the base vocab.
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Share the input embedding matrix with the output head.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            # Non-final pipeline stages get a placeholder instead of a head.
            self.lm_head = PPMissingLayer()

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
zhuwenwen's avatar
zhuwenwen committed
693
        
Woosuk Kwon's avatar
Woosuk Kwon committed
694

695
696
697
    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Create the decoder backbone; called from ``__init__``.

        Kept as a separate method so it can presumably be overridden to
        swap in a different backbone class.
        """
        backbone = LlamaModel(vllm_config=vllm_config, prefix=prefix)
        return backbone

698
699
700
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids``, delegating to the backbone."""
        backbone = self.model
        embeddings = backbone.get_input_embeddings(input_ids)
        return embeddings

Woosuk Kwon's avatar
Woosuk Kwon committed
701
702
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone model on one batch.

        All arguments are forwarded positionally to ``self.model`` and its
        result is returned unchanged: a tensor or ``IntermediateTensors``
        per the annotation (presumably hidden states on the last pipeline
        stage and intermediate tensors elsewhere — confirm in LlamaModel).
        """
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)
        return model_output
714

715
716
717
718
719
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits.

        Delegates to ``self.logits_processor`` with ``self.lm_head``. Both
        are only fully set up on the last pipeline stage (see ``__init__``).
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

724
725
    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using the sampler built in ``__init__``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

zhuwenwen's avatar
zhuwenwen committed
729

730
731
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights into this model.

        Every ``(name, tensor)`` pair is first passed through
        ``maybe_remap_mistral``. When ``tie_word_embeddings`` is set,
        entries under ``lm_head.`` are skipped, because ``__init__`` ties
        the head to the input embedding. Returns the result of
        ``AutoWeightsLoader.load_weights`` (per the annotation, a set of
        names — presumably the parameters that were loaded).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        # Generator: the remap runs lazily as the loader consumes weights.
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)
740

zhuwenwen's avatar
zhuwenwen committed
741

742
743
744
    # Remap checkpoint names (and permute q/k weights for rotary embeddings)
    # from the Mistral format, as used by Mistral and Llama <= 2.
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate one Mistral-format checkpoint entry to this model's naming.

        Args:
            name: Parameter name as stored in the checkpoint.
            loaded_weight: The corresponding tensor.

        Returns:
            ``(name, loaded_weight)``. ``name`` is rewritten through
            ``self.mistral_mapping``. For ``wq``/``wk`` ``weight`` entries,
            the tensor rows are also permuted (see ``permute``). Entries
            matching no mapping key are returned unchanged.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Within each head, move rows from interleaved pairs (2j, 2j+1)
            # into two contiguous halves (j and head_dim/2 + j). Presumably
            # this converts the interleaved rotary layout of Mistral/Meta
            # checkpoints into the rotate-half layout used here — confirm.
            # Assumes w is (n_heads * head_dim, hidden_size); view() raises
            # if the element count does not match.
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        last_index = len(modules) - 1
        for i, item in enumerate(modules):
            next_item = modules[i + 1] if i < last_index else None

            # Two-component keys (e.g. "kv_fake_quantizer.qscale_act") are
            # checked before the single-component key at this position.
            combined_item = (f"{item}.{next_item}"
                             if next_item is not None else None)

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            # Skip when the target already appears in the name, e.g. so
            # "layers" is not rewritten to "model.layers" a second time.
            # NOTE(review): str.replace rewrites every occurrence of the
            # substring in the whole name, not only this dotted component.
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight