# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
25
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29

import torch
from torch import nn
from transformers import LlamaConfig
zhuwenwen's avatar
zhuwenwen committed
30
import os
gaoqiong's avatar
gaoqiong committed
31
import re
32
import vllm.envs as envs
33
from vllm.attention import Attention, AttentionMetadata
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
45
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
48
from vllm.model_executor.model_loader.weight_utils import (
49
    default_weight_loader, maybe_remap_kv_scale_name)
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
51
from vllm.sequence import IntermediateTensors
zhuwenwen's avatar
zhuwenwen committed
52
from vllm.utils import W8a8GetCacheJSON
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsLoRA, SupportsPP
55
56
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

gaoqiong's avatar
gaoqiong committed
60
from vllm import _custom_ops as ops
61
62
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

Woosuk Kwon's avatar
Woosuk Kwon committed
63
64

class LlamaMLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection, SiLU-and-mul, down.

    This fork passes extra fusion flags to the parallel linear layers:
    ``fuse_ag_gemm`` on ``gate_up_proj`` and ``fuse_gemm_rs`` on
    ``down_proj``. Judging by their names these fuse an all-gather in front
    of the GEMM and a reduce-scatter after it, respectively
    (NOTE(review): confirm against the linear-layer implementation).

    Args:
        hidden_size: Model hidden dimension (input of gate/up, output of down).
        intermediate_size: Width of each of the gate and up branches.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config forwarded to the linears.
        bias: Whether the projections carry a bias term.
        prefix: Parameter-name prefix for this module.
        last_layer: True for the final decoder layer; disables the fused
            reduce-scatter on ``down_proj``.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        last_layer: bool = False,
    ) -> None:
        super().__init__()
        # Validate before allocating the (potentially large) projection weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            # The attention o_proj always uses a fused reduce-scatter, so the
            # MLP input presumably arrives sharded and must be gathered here.
            fuse_ag_gemm=True,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            # The last layer must hand a full (un-scattered) tensor to the
            # final norm, so it keeps the plain reduction path.
            fuse_gemm_rs=(not last_layer),
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Project to gate/up, apply SiLU-and-mul, project back down."""
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out


class LlamaAttention(nn.Module):
    """Multi-head / grouped-query self-attention with rotary embeddings.

    Query heads are partitioned across the tensor-parallel group. KV heads are
    partitioned when there are at least as many as TP ranks, and replicated
    otherwise.

    Args:
        config: HF model config (read for ``head_dim``, ``model_type`` and
            ``interleaved_sliding_window``).
        hidden_size: Model hidden dimension.
        num_heads: Total number of query heads across all TP ranks.
        num_kv_heads: Total number of key/value heads across all TP ranks.
        first_layer: True for the first decoder layer. Its input is not
            sharded, so the fused all-gather on ``qkv_proj`` is disabled.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling configuration.
        max_position_embeddings: Maximum position supported by the rotary
            embedding.
        quant_config: Optional quantization config.
        bias: Whether the QKV projection has a bias.
        bias_o_proj: Whether the output projection has a bias.
        cache_config: KV-cache configuration passed to ``Attention``.
        prefix: Parameter-name prefix; the layer index is parsed from it.

    Raises:
        ValueError: If ``config.interleaved_sliding_window`` is set to a type
            other than ``int`` or ``list``.
    """

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 first_layer: bool,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
        # Some configs define the attribute but leave it as None; fall back
        # to the derived value in that case too.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            # Every layer after the first receives a sharded hidden state
            # (the previous layer's MLP reduce-scatters), so it gathers here.
            fuse_ag_gemm=(not first_layer),
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
            fuse_gemm_rs=True,
        )

        # GGUF checkpoints of the "llama" architecture use the non-NeoX
        # (interleaved) rotary layout.
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        is_neox_style = not (is_gguf and config.model_type == "llama")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        # A missing or None interleaved_sliding_window means full attention.
        # An int applies to every layer; a list is cycled by layer index.
        interleaved_sliding_window = getattr(config,
                                             "interleaved_sliding_window",
                                             None)
        if interleaved_sliding_window is None:
            sliding_window = None
        elif isinstance(interleaved_sliding_window, int):
            sliding_window = interleaved_sliding_window
        elif isinstance(interleaved_sliding_window, list):
            sw_idx = layer_idx % len(interleaved_sliding_window)
            sliding_window = interleaved_sliding_window[sw_idx]
        else:
            raise ValueError(
                f"{type(interleaved_sliding_window)} is not supported.")

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so readers never hit AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states`` at ``positions``.

        Returns the ``o_proj`` output. Because ``o_proj`` uses the fused
        reduce-scatter, this is presumably this rank's token shard rather
        than the full sequence (NOTE(review): confirm).
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """One Llama transformer block: pre-norm attention + pre-norm gated MLP.

    This fork replaces the per-layer AllReduce with ReduceScatter +
    AllGather. The residual stream is therefore kept sharded across
    tensor-parallel ranks between layers:

    * the first layer splits the full residual into per-rank token chunks;
    * middle layers keep operating on their own shard;
    * the last layer all-gathers the residual back to full size.

    Args:
        config: HF model config.
        first_layer: True for the first decoder layer of the model.
        last_layer: True for the last decoder layer of the model.
        cache_config: KV-cache configuration.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix for this layer.
    """

    def __init__(
        self,
        config: LlamaConfig,
        # Hack: pass in whether this is the first/last layer
        # so we know if we can rewrite AllReduce -> ReduceScatter + AllGather,
        # and then propagate the AllGather to the next layer.
        first_layer: bool,
        last_layer: bool,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # Support internlm/internlm3-8b with qkv_bias (affects QKV only).
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            first_layer=first_layer,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
            last_layer=last_layer,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        self.first_layer = first_layer
        self.last_layer = last_layer
        # The TP group is already initialized at construction time (the
        # attention module queries it), and neither value changes afterwards,
        # so look them up once instead of on every forward call.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(hidden_states, residual)``.

        ``residual`` is None only for the very first layer of the first
        pipeline rank. In that case the raw input becomes the residual.
        """
        # Self attention (pre-norm; the fused RMSNorm also adds the residual).
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            assert hidden_states.shape == residual.shape
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        # o_proj reduce-scatters along dim 0, so the attention output only
        # holds this rank's slice of tokens. The first layer still has the
        # full residual and must take the matching slice. Later layers already
        # carry a sharded residual. NOTE(review): torch.chunk returns fewer
        # than tp_size chunks when the token count is small or not evenly
        # divisible; this presumably relies on the caller padding the token
        # count to a multiple of tp_size — confirm.
        if self.first_layer:
            my_residual = torch.chunk(residual, self.tp_size,
                                      dim=0)[self.tp_rank]
        else:
            my_residual = residual

        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states,
                                       kv_cache=kv_cache,
                                       attn_metadata=attn_metadata)

        # Fully connected.
        assert hidden_states.shape == my_residual.shape
        hidden_states, my_residual = self.post_attention_layernorm(
            hidden_states, my_residual)
        hidden_states = self.mlp(hidden_states)

        # The last layer's MLP returns full-size activations, so the residual
        # is gathered back to match before the final norm.
        if self.last_layer:
            residual = tensor_model_parallel_all_gather(my_residual, 0)
        else:
            residual = my_residual

        assert hidden_states.shape == residual.shape
        return hidden_states, residual
Woosuk Kwon's avatar
Woosuk Kwon committed
336
337


338
@support_torch_compile
class LlamaModel(nn.Module):
    """Llama decoder stack: token embedding, decoder layers and final norm.

    Under pipeline parallelism each rank owns only the layers in
    ``[start_layer, end_layer)``. The embedding lives on the first rank (and
    on the last rank when embeddings are tied). The final norm lives on the
    last rank.

    Besides standard weight loading, ``load_weights`` applies fork-specific,
    environment-controlled weight layout transformations once the weights are
    in place (``LLAMA_NN``, ``AWQ_PAD``, ``W8A8_SUPPORT_METHODS``).
    """

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        # This fork's make_layers also tells each factory call whether it is
        # building the first/last layer (needed for residual sharding).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix, first_layer, last_layer: layer_type(
                config=config,
                first_layer=first_layer,
                last_layer=last_layer,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Quantization method name (e.g. "awq", "compressed_tensors"), or None
        # for unquantized models; selects the post-load processing path.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Cache of tuned Triton W8A8 GEMM configs.
        self.tritonsingleton = W8a8GetCacheJSON()

        # Environment switches for fork-specific weight layouts.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
        # 1 selects the Triton W8A8 path; other values leave weights to be
        # pre-transposed for the non-Triton kernels (see _prepare_w8a8_weights).
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Returns the final normalized hidden states on the last pipeline rank.
        On other ranks it returns ``IntermediateTensors`` carrying
        ``hidden_states`` and ``residual`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors, then apply fork-specific layout passes.

        Each incoming tensor is expected to carry ``current_count`` and
        ``total_count`` attributes (attached by this fork's weight iterator).
        The ``LLAMA_NN`` layout pass only runs when the last tensor seen has
        ``current_count == total_count``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        # Tracks whether the most recent tensor was the final one. It stays
        # False for an empty iterator (previously an UnboundLocalError).
        reached_last_weight = False
        for name, loaded_weight in weights:
            reached_last_weight = (
                loaded_weight.current_count == loaded_weight.total_count)
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if (self.use_llama_nn and self.quant_method is None
                and reached_last_weight):
            self._apply_llama_nn_layout(params_dict)
        else:
            os.environ['LM_NN'] = '0'
            os.environ['LLAMA_NN'] = '0'

        if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
            self._repack_awq_weights(params_dict, loaded_params)

        if self.quant_method == "compressed_tensors":
            self._prepare_w8a8_weights(params_dict)

        return loaded_params

    def _apply_llama_nn_layout(self,
                               params_dict: Dict[str, nn.Parameter]) -> None:
        """Rewrite fp16/bf16 linear weights into the "NN" GEMM layout.

        Also publishes ``LM_NN`` ('1' when an ``lm_head`` with
        ``shape[1] >= 4096`` is present and transformed, else '0').
        """
        key_words = [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]
        # Decide LM_NN once for the whole model. Previously it was rewritten
        # per parameter, so its final value depended on dict order.
        use_lm_nn = any("lm_head.weight" in layername and p.shape[1] >= 4096
                        for layername, p in params_dict.items())
        if use_lm_nn:
            key_words.append("lm_head.weight")
        os.environ['LM_NN'] = '1' if use_lm_nn else '0'

        pattern = re.compile("|".join(key_words))
        for layername, weight in params_dict.items():
            if not pattern.search(layername):
                continue
            ori_shape = weight.data.shape
            transformed = torch.zeros_like(weight.data)
            # Opaque custom op; presumably writes a transposed copy of the
            # weight into `transformed` (the reshape below swaps the dims).
            ops.trans_w16_gemm(transformed, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(transformed)
            weight.data = weight.data.reshape(ori_shape[1], -1)

    def _repack_awq_weights(self, params_dict: Dict[str, nn.Parameter],
                            loaded_params: Set[str]) -> None:
        """Convert loaded AWQ int4 weights into the custom-kernel layout.

        Fills ``zeros_and_scales`` from ``qzeros``/``scales`` and reshapes it
        and ``qweight`` to be N-major. When ``AWQ_PAD=1`` and K is a multiple
        of 4096, K is zero-padded by two quantization groups.
        """
        pattern = re.compile("|".join([
            "self_attn.qkv_proj.qweight",
            "self_attn.o_proj.qweight",
            "mlp.gate_up_proj.qweight",
            "mlp.down_proj.qweight",
        ]))
        group_size = int(self.quant_config.group_size)
        pad_group = 2
        for layername in loaded_params:
            if not pattern.search(layername):
                continue
            qweight = params_dict[layername]
            qzeros = params_dict[layername.replace("qweight", "qzeros")]
            scales = params_dict[layername.replace("qweight", "scales")]
            zeros_and_scales = params_dict[layername.replace(
                "qweight", "zeros_and_scales")]

            dim_n = scales.data.shape[1]
            dim_k = qweight.data.shape[0]

            packed_qw, packed_sz = ops.convert_s4(qweight, qzeros, scales,
                                                  group_size)
            sz = ops.sz_permute(packed_sz).reshape(-1, dim_n)
            zeros_and_scales.data.copy_(sz)
            qweight.data.copy_(packed_qw)

            # zeros_and_scales: [k/group_size, n] -> [n, k/group_size]
            # qweight:          [k, n/8]          -> [n, k/8]
            zeros_and_scales.data = zeros_and_scales.data.reshape(dim_n, -1)
            qweight.data = qweight.data.reshape(dim_n, -1)

            if dim_k % 4096 == 0 and self.use_awq_pad:
                # Two extra groups: pad_group zero/scale columns, and
                # 2 * group_size / 8 == group_size // 4 packed int32 columns.
                zs_pad = torch.zeros(dim_n,
                                     pad_group,
                                     dtype=torch.int32,
                                     device=zeros_and_scales.device)
                zeros_and_scales.data = torch.cat(
                    (zeros_and_scales.data, zs_pad), dim=1).contiguous()
                qw_pad = torch.zeros(dim_n,
                                     group_size // 4,
                                     dtype=torch.int32,
                                     device=qweight.device)
                qweight.data = torch.cat((qweight.data, qw_pad),
                                         dim=1).contiguous()

    def _prepare_w8a8_weights(self,
                              params_dict: Dict[str, nn.Parameter]) -> None:
        """Prepare compressed-tensors W8A8 weights for the selected backend.

        The rocBLAS and CUTLASS paths (``w8a8_strategy != 1``) need the weight
        data transposed in place. The Triton path must NOT transform the
        weights. Instead it collects tuned configs for up to four distinct
        projection kinds and warms each one up.
        """
        pattern = re.compile("|".join([
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
        ]))
        all_json = {}
        matched_key_words = set()
        for layername, weight in params_dict.items():
            match = pattern.search(layername)
            if match is None or "scale" in layername:
                continue
            n = weight.shape[0]
            if self.w8a8_strategy != 1:
                transposed = weight.T.contiguous().reshape(n, -1)
                weight.data.copy_(transposed)
            elif (len(matched_key_words) < 4
                  and match.group(0) not in matched_key_words):
                # Record the (n, k) shape once per projection kind.
                matched_key_words.add(match.group(0))
                k = weight.shape[1]
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(
                    json_file, n, k)
                if configs_dict:
                    all_json.update(configs_dict)

        if self.w8a8_strategy == 1:
            self.tritonsingleton.triton_json_dict.append(all_json)
            # Warm up every config found; keys look like "<m>_<n>_<k>...".
            for key, value in all_json.items():
                parts = key.split('_')
                m, n, k = int(parts[0]), int(parts[1]), int(parts[2])
                ops.triton_int8_gemm_helper(m=m,
                                            n=n,
                                            k=k,
                                            per_token_act_quant=True,
                                            per_out_channel_weight_quant=True,
                                            use_bias=False,
                                            best_config=value)
642

Woosuk Kwon's avatar
Woosuk Kwon committed
643

644
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
Terry's avatar
Terry committed
645
    # Each fused module name maps to the separate checkpoint sub-modules it
    # combines (e.g. q/k/v projections stacked into qkv_proj).
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints.
    # Mistral-format name component -> this model's name component; used by
    # maybe_remap_mistral to rewrite checkpoint parameter names.
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the LLaMA causal LM.

        The decoder backbone is always built. The LM head and logits
        processor are built only on the last pipeline-parallel rank.

        Args:
            vllm_config: Engine config; ``model_config.hf_config``,
                ``quant_config`` and ``lora_config`` are read from it.
            prefix: Name prefix; submodules are named with
                ``maybe_prefix(prefix, "model")`` and
                ``maybe_prefix(prefix, "lm_head")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            # LoRA can add extra vocabulary entries on top of the base vocab.
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config else
                    lora_config.lora_vocab_padding_size),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                # Reuse the backbone's input embedding weights for the head.
                self.lm_head = self.lm_head.tie_weights(
                    self.model.embed_tokens)

            # Falls back to 1.0 when the HF config defines no logit_scale.
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size,
                                                    logit_scale)
        else:
            # Non-last pipeline ranks own no LM head; PPMissingLayer stands
            # in for it.
            self.lm_head = PPMissingLayer()

        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        """Construct the decoder backbone (a ``LlamaModel``).

        Kept as its own method, presumably so subclasses can override the
        backbone type.
        """
        return LlamaModel(vllm_config=vllm_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return embeddings for ``input_ids``, delegating to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder backbone and return its output unchanged.

        Per the return annotation, the result is either a tensor or
        ``IntermediateTensors``. The latter is presumably produced on
        non-last pipeline ranks (TODO confirm against ``LlamaModel``).
        Logits are not computed here; see ``compute_logits``.
        """
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata, intermediate_tensors,
                                  inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to logits through ``self.lm_head``.

        ``self.logits_processor`` is only created on the last
        pipeline-parallel rank (see ``__init__``), so this is only callable
        there.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` with ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Each pair is first passed through ``maybe_remap_mistral`` and then
        handed to ``AutoWeightsLoader``. When word embeddings are tied,
        checkpoint entries under ``lm_head.`` are skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated
            here as a set of strings, presumably the loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings the head reuses embed_tokens' weight (see
            # __init__), so separate lm_head tensors are not loaded.
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        # Remap lazily; the generator is consumed by the loader.
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights)

    # This function is used to remap the mistral format as
    # used by Mistral and Llama <=2
    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> Tuple[str, torch.Tensor]:
        """Translate a Mistral-format weight name (and q/k row layout).

        ``wq``/``wk`` tensors are row-permuted per head (see ``permute``).
        Every dotted component of ``name`` that is a key of
        ``self.mistral_mapping`` is then replaced by its mapped value. The
        replacement is skipped when the mapped value already occurs in
        ``name``, which keeps a mapping from being applied twice.

        Args:
            name: Checkpoint parameter name.
            loaded_weight: Tensor loaded for ``name``.

        Returns:
            ``(possibly_renamed_name, possibly_permuted_weight)``.
        """

        def permute(w: torch.Tensor, n_heads: int):
            # Requires w to hold head_dim * n_heads * hidden_size elements.
            # Within each head, new row j * half + i takes old row 2 * i + j.
            # So the even rows come first, then the odd rows, where
            # half = head_dim // 2.
            attn_in = self.config.head_dim * n_heads
            attn_out = self.config.hidden_size

            return w.view(n_heads, attn_in // n_heads // 2, 2,
                          attn_out).transpose(1, 2).reshape(attn_in, attn_out)

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        if "wk" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_key_value_heads)
        elif "wq" in modules:
            loaded_weight = permute(loaded_weight,
                                    self.config.num_attention_heads)

        for item in modules:
            if item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight