linear.py 74.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
from vllm import envs
13
14
15
16
17
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
18
from vllm.logger import init_logger
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
32

zhuwenwen's avatar
zhuwenwen committed
33
import os
34
from vllm.model_executor.utils import gemm_bank_conf
35

36
37
38
39
40
# Best-effort import of ``lm_faster_rmsquant``; only attempted when
# ``envs.USE_FUSED_RMS_QUANT`` is truthy. On failure the error is printed and
# the name stays undefined, so a later call to it would raise NameError.
if envs.USE_FUSED_RMS_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: Import fused rmsquant error: {e}") 
41
42
43
44
45
46
47
# Best-effort import of ``lm_fuse_silu_mul_quant``; only attempted when
# ``envs.USE_FUSED_SILU_MUL_QUANT`` is truthy. On failure the error is printed
# and the name stays undefined.
if envs.USE_FUSED_SILU_MUL_QUANT:
    try:
        # from lightop import fuse_silu_mul_quant
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: Import fused silu_mul_quant error: {e}")
        
48
49
logger = init_logger(__name__)

50
# Class names of quantization methods for which a linear layer hands
# ``weight_loader_v2`` (instead of ``weight_loader``) to ``create_weights``.
# Membership is tested against ``quant_method.__class__.__name__``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
72

73

74
75
76
77
78
79
80
81
82
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Return ``(shard_size, shard_offset)`` adjusted for ``param``.

    If ``param`` carries a non-None ``bitblas_tile_size`` attribute, both
    values are floor-divided by it; otherwise they are returned untouched.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


83
84
85
86
87
88
89
90
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Return ``(shard_size, shard_offset)`` adjusted for ``param``.

    If ``param`` carries a non-None ``marlin_tile_size`` attribute, both
    values are multiplied by it; otherwise they are returned untouched.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


91
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Args:
        param: Parameter whose ``data.shape[0]`` is the quantized total.
        shard_offsets: Maps a shard id to ``(offset, size)``. The first
            element of the ``"total"`` entry is the unquantized total.
        loaded_shard_id: Key of the shard to rescale.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note the order is size
        first, offset second.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    # Scale by quantized_total / total; multiply before the floor division
    # so no precision is lost.
    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: Indexable array of per-shard scales.
        loaded_weight: Scale loaded from disk; either shapeless or with a
            leading dimension of size 1.
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[shard_id], loaded_weight)`` with ``loaded_weight`` stripped
        of its leading dimension (and passed through ``.t()`` when
        ``envs.VLLM_USE_NN`` is set).

    Raises:
        ValueError: If ``shard_id`` is neither a str nor an int.
        KeyError: If ``shard_id`` is a str other than ``"q"``/``"k"``/``"v"``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    if envs.VLLM_USE_NN:
        return param[shard_id], loaded_weight.t()
    return param[shard_id], loaded_weight
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the
    remaining shards.

    The first returned dict keeps the first two offsets and quant state 0.
    The second keeps the offsets from index 1 onwards, rebased so that they
    start at zero, and the remaining quant states re-keyed from 0.

    Example input:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    Example output:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Element-wise subtraction rebases the tail on the second offset.
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    shard_count = len(offsets) - 1
    tail_states = {idx - 1: states[idx] for idx in range(1, shard_count)}

    left = {
        "bnb_shard_offsets": head_offsets,
        "bnb_quant_state": head_states,
    }
    right = {
        "bnb_shard_offsets": tail_offsets,
        "bnb_quant_state": tail_states,
    }
    return left, right


169
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights (e.g. a ``weight_loader`` callable).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def __init__(self):
        # Both switches are read once from the environment at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the (uninitialised) ``weight`` parameter on ``layer``.

        With ``envs.VLLM_USE_NN`` the weight is allocated as
        ``(input_size_per_partition, sum(output_partition_sizes))``;
        otherwise as ``(sum(output_partition_sizes),
        input_size_per_partition)``. In both cases the parameter is tagged
        with ``input_dim=1`` and ``output_dim=0`` plus ``extra_weight_attrs``.
        """
        total_output_size = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            weight_shape = (input_size_per_partition, total_output_size)
        else:
            weight_shape = (total_output_size, input_size_per_partition)
        weight = Parameter(torch.empty(*weight_shape, dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Optionally repack the weight for the CPU SGL kernels.

        Only acts on CPU with ``envs.VLLM_CPU_SGL_KERNEL`` set; records the
        outcome in ``layer.use_cpu_sgl``.
        """
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            N, K = layer.weight.size()
            dtype = layer.weight.dtype
            if (torch._C._cpu._is_amx_tile_supported()
                    and dtype == torch.bfloat16 and N % 32 == 0
                    and K % 32 == 0):
                packed_weight = torch.ops._C.convert_weight_packed(
                    layer.weight)
                assert packed_weight.size() == layer.weight.size()
                layer.weight.copy_(packed_weight)
                if layer.bias is not None:
                    layer.bias = Parameter(layer.bias.to(torch.float32),
                                           requires_grad=False)
                layer.use_cpu_sgl = True
            else:
                logger.warning(
                    "CPU SGL kernels require Intel AMX support,"
                    " bfloat16 weight, IC and OC are divisible by 32.")
                layer.use_cpu_sgl = False

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer weight and add ``bias`` if given.

        With ``LLAMA_NN=1`` the weight is used as the right operand of a
        plain matmul/addmm. Otherwise the GEMM returned by
        ``dispatch_unquantized_gemm()`` is used.
        """
        if self.use_llama_nn:
            weight = layer.weight
            # Use the flag cached in __init__: indexing os.environ here
            # raised KeyError whenever GEMM_PAD was unset.
            if gemm_bank_conf(weight.shape[1] - 32) and self.use_gemm_pad:
                # Drop the last 32 columns through a local view. Assigning
                # the slice back to ``layer.weight`` would hand a plain
                # tensor to a registered parameter (TypeError in nn.Module)
                # and would mutate the layer from the forward path.
                # NOTE(review): presumably the last 32 columns are padding
                # added at load time -- confirm against the loader.
                weight = weight[:, :-32]

            if bias is None:
                return torch.matmul(x, weight)
            if len(x.shape) == 2:
                return torch.addmm(bias, x, weight)
            return torch.matmul(x, weight) + bias

        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            # The weight was allocated as (in, out); transpose it back for
            # the dispatched GEMM.
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)
270

271

272
273
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None, an
            ``UnquantizedLinearMethod`` is used.
        prefix: Passed to ``quant_config.get_quant_method``.
        return_bias: If true, return bias together with outputs in forward pass.

    Note:
        This base class does not create a bias itself.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Not implemented here; subclasses provide the forward pass."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to ``lm_faster_rmsquant`` in the fused path of
            ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        Raises:
            AssertionError: If the shapes differ after the adjustments below.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # With VLLM_USE_NN the unquantized method allocates its weight
        # transposed, so transpose the checkpoint tensor to match.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()}"
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally through the fused RMS-quant path.

        The fused path is taken when ``envs.USE_FUSED_RMS_QUANT`` is set and
        either ``rms_weight`` or ``quant_args`` is given:

        * ``quant_args`` given: it is passed to the quant method as-is.
        * only ``rms_weight`` given: ``lm_faster_rmsquant`` is called on
          ``input_`` and its two results are passed to the quant method.

        Returns:
            ``output`` alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)``, except in the ``rms_weight``-only
            fused case, which returns the 4-tuple
            ``(output, residual, output_bias, [i_q, scales])``.
        """
        assert self.quant_method is not None
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None

        use_fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                                  or quant_args is not None)
        if not use_fused:
            output = self.quant_method.apply(self, input_, bias)
            if not self.return_bias:
                return output
            return output, output_bias

        if quant_args is not None:
            # The caller already quantized the input; reuse its arguments.
            output = self.quant_method.apply(self, input_, bias, quant_args)
            if not self.return_bias:
                return output
            return output, output_bias

        i_q, _scales = lm_faster_rmsquant(input=input_,
                                          rms_weight=rms_weight,
                                          epsilon=self.eps,
                                          quant_dtype=torch.int8,
                                          residual=residual,
                                          update_input=update_hd)
        input_quant_args = [i_q, _scales]
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        if not self.return_bias:
            return output
        # The quantized input is returned so a caller can pass it back in as
        # ``quant_args`` on a later call.
        return output, residual, output_bias, input_quant_args

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

456

457
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This constructor does not
                       read the argument; per-partition sizes come from
                       ``self.output_sizes`` when that attribute is already
                       set on the instance.
        eps: Epsilon passed to ``lm_faster_rmsquant`` in the fused path of
             ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)
        self.eps = eps
        self.gather_output = gather_output

        # ``output_sizes`` is accepted for interface compatibility only; the
        # former local default assignment was never read and was removed.

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The tensor is narrowed along the parameter's ``output_dim`` unless it
        is already sharded (``is_sharded_weight`` or ``use_bitsandbytes_4bit``
        set on the parameter) or the parameter has no ``output_dim``.

        Raises:
            AssertionError: If the final shapes do not match.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                tp_size = get_tensor_model_parallel_world_size()
                assert final_shape[output_dim] % tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                shard_size = param_data.shape[output_dim]
            else:
                # Unquantized 2-D weights are stored transposed under
                # VLLM_USE_NN, so the per-rank output size sits on the
                # other axis of the parameter.
                shard_size = param_data.shape[int(not (output_dim))]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Transpose only after narrowing, to match the transposed storage.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally through the fused RMS-quant path.

        The fused path is taken when ``envs.USE_FUSED_RMS_QUANT`` is set and
        ``rms_weight`` is given: ``lm_faster_rmsquant`` is called on
        ``input_`` and its two results are passed to the quant method.

        Returns:
            ``output`` alone when ``return_bias`` is false. Otherwise
            ``(output, residual, output_bias)`` on the fused path and
            ``(output, output_bias)`` on the regular path.
        """
        assert self.quant_method is not None
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None

        use_fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None
        if use_fused:
            i_q, _scales = lm_faster_rmsquant(input=input_,
                                              rms_weight=rms_weight,
                                              epsilon=self.eps,
                                              quant_dtype=torch.int8,
                                              residual=residual,
                                              update_input=update_hd)
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      [i_q, _scales])
        else:
            # Matrix multiply.
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        if use_fused:
            return output, residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        eps: Epsilon forwarded to the fused RMS-norm + quantization kernel
             used by ``forward`` when ``envs.USE_FUSED_RMS_QUANT`` is set.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Both attributes are assigned before the parent constructor runs,
        # exactly as in the original code.
        self.eps = eps
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        # Every merged sub-projection must split evenly across the TP ranks.
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the merged column-parallel projection.

        When ``envs.USE_FUSED_RMS_QUANT`` is enabled and ``rms_weight`` is
        given, the input is first passed through ``lm_faster_rmsquant``
        (int8 output) and the resulting quantized tensor and scales are
        handed to the quantization method as an extra argument.

        Args:
            input_: Input activations.
            rms_weight: RMS-norm weight; enables the fused path when set.
            residual: Residual tensor; required on the fused path.
            update_hd: Forwarded to the kernel as ``update_input``.

        Returns:
            ``output`` alone when ``return_bias`` is false. Otherwise
            ``(output, residual, output_bias)`` on the fused path and
            ``(output, output_bias)`` on the regular path, where
            ``output_bias`` is the bias only if ``skip_bias_add`` is set.
        """
        use_fused_rms_quant = (envs.USE_FUSED_RMS_QUANT
                               and rms_weight is not None)

        quant_extra_args = None
        if use_fused_rms_quant:
            assert residual is not None
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            quant_extra_args = [i_q, scales]

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        if use_fused_rms_quant:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      quant_extra_args)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        if use_fused_rms_quant:
            # The same residual object passed in is handed back to the caller.
            return output, residual, output_bias
        return output, output_bias

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy one checkpoint tensor (or one shard of it) into ``param``.

        Args:
            param: Destination parameter; its attributes (``output_dim``,
                ``packed_dim``, ``is_gguf_weight`` ...) select the path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``; ``None`` means
                the checkpoint already stores the sub-projections fused.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        # With VLLM_USE_NN, unquantized weights are narrowed on the opposite
        # axis and the checkpoint tensor is transposed before the copy.
        use_nn_layout = envs.VLLM_USE_NN and not is_quantization

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Unsharded start offset of every sub-projection, plus the total.
            index = list(itertools.accumulate([0] + self.output_sizes))
            shard_offsets: list[tuple[int, int, int]] = [
                (i, index[i], output_size)
                for i, output_size in enumerate(self.output_sizes)
            ]

            orig_offsets = None
            if use_bitsandbytes_4bit:
                # Loop-invariant lookup table, built once for all shards.
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id for the per-shard path.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # Only multi-dimensional parameters have an "other" axis to pick;
            # a 1-D parameter is always narrowed along ``output_dim``.
            # NOTE(review): presumably the 1-D case is the bias - confirm.
            if use_nn_layout and param_data.dim() > 1:
                param_narrow_dim = int(not output_dim)
            else:
                param_narrow_dim = output_dim
            param_data = param_data.narrow(param_narrow_dim, shard_offset,
                                           shard_size)

            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if use_nn_layout:
            # ``Tensor.t()`` is a no-op for tensors with fewer than 2 dims.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Special case for Quantization.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        needs_packing_adjustment = (
            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim)

        current_shard_offset = 0
        for shard_id, output_size in enumerate(self.output_sizes):
            shard_offset = current_shard_offset
            shard_size = output_size
            current_shard_offset += output_size

            if needs_packing_adjustment:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor through the parameter's own loader.

        Args:
            param: Destination parameter exposing
                ``load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``; ``None`` means
                the checkpoint already stores the sub-projections fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match on purpose: subclasses take the split path.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Block scales are indexed per block: ceil-divide by the block
            # size first, then split across the TP ranks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Set up the head partitioning and build the fused QKV projection.

        Args:
            hidden_size: Input width of the projection.
            head_size: Width of a single attention head.
            total_num_heads: Number of query heads across all ranks.
            total_num_kv_heads: Number of key/value heads across all ranks;
                defaults to ``total_num_heads`` when ``None``.
            bias: Forwarded to the parent constructor.
            skip_bias_add: Forwarded to the parent constructor.
            params_dtype: Forwarded to the parent constructor.
            quant_config: Forwarded to the parent constructor.
            prefix: Forwarded to the parent constructor.
            return_bias: Forwarded to the parent constructor.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and every KV head is replicated over several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size

        # Widths summed over all ranks (replicated KV heads counted per rank).
        q_size = self.num_heads * self.head_size * tp_size
        kv_size = self.num_kv_heads * self.head_size * tp_size
        output_size = q_size + 2 * kv_size
        self.output_sizes = [
            q_size,  # q_proj
            kv_size,  # k_proj
            kv_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
1064

1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Offsets and sizes here use the head counts over all ranks.
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

        # Special case for Quantization.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        needs_packing_adjustment = (
            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim)

        for shard_id, shard_offset, shard_size in shard_offsets:
            if needs_packing_adjustment:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a Q, K or V checkpoint tensor through the parameter's loader.

        Args:
            param: Destination parameter exposing ``load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``; ``None`` means the
                checkpoint already stores the three projections fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match on purpose: subclasses take the split path.
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            # Same guard as the merged-column loader: fail with a clear
            # assertion instead of a TypeError on ``None[0]``.
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Block scales are indexed per block: ceil-divide by block size.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

1153
1154
1155
1156
    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
1157
1158
1159
1160
1161

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
1162
        if is_gguf_weight_type:
1163
            idx_map = {"q": 0, "k": 1, "v": 2}
1164
1165
1166
1167
1168
1169
1170
1171
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
1172
1173
            return

1174
1175
1176
        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()
1177

1178
1179
1180
1181
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

1182
1183
1184
1185
1186
1187
1188
            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
1189

1190
1191
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
1192
        # Special case for AQLM codebooks.
James Fleming's avatar
James Fleming committed
1193
        is_metadata = getattr(param, "is_metadata", False)
1194

1195
1196
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
1197
        is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
1198

1199
        if loaded_shard_id is None:
1200
1201
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
1202
            if output_dim is None:
1203
                if needs_scalar_to_array:
1204
1205
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)
1206

1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
1218
1219
1220
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

1221
1222
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
1223
                # Special case for Quantized Weights.
1224
1225
1226
1227
1228
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
1229

1230
                    # Special case for Marlin.
1231
1232
1233
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

1251
1252
1253
1254
1255
1256
1257
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
1258
1259

        # If output dim is defined, use the default loading process.
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
1271
            # Special case for Quantized Weights.
1272
1273
1274
1275
1276
1277
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
1278

1279
                # Special case for Marlin.
1280
1281
1282
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

1283
1284
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
1285
1286
1287
1288
1289
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

1290
            if use_bitsandbytes_4bit:
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
1302
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
1303
                    param, orig_qkv_offsets, loaded_shard_id)
gaoqiong's avatar
gaoqiong committed
1304

1305
1306
1307
1308
1309
1310
1311
            if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not(output_dim)), shard_offset,
                                               shard_size)
                
zhuwenwen's avatar
zhuwenwen committed
1312
            if loaded_shard_id == "q":
1313
1314
1315
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
1316
            start_idx = shard_id * shard_size
1317

1318
            if not is_sharded_weight:
1319
1320
1321
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

1322
        # Special case for AQLM codebooks.
James Fleming's avatar
James Fleming committed
1323
1324
1325
1326
1327
1328
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
1329
1330
1331
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
1332
                param_data, loaded_weight, loaded_shard_id)
1333
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
1334
1335
1336
1337
1338
1339
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
gaoqiong's avatar
gaoqiong committed
1340

1341
1342
1343
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()
            
gaoqiong's avatar
gaoqiong committed
1344
1345
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1346
1347


1348
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.

    Raises:
        ValueError: if `reduce_results` is false while a bias would be added
            to the output (`bias and not skip_bias_add`).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the first dimension.
        # These attributes are assigned before `super().__init__` on purpose;
        # the original ordering is preserved.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2 loader;
        # every other method falls back to the legacy `weight_loader`.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is not sharded: it holds the full `output_size`.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        # Imported at construction time rather than at module import time.
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of `loaded_weight` into `param`.

        When `param` carries an `input_dim` attribute and the checkpoint
        tensor is not already sharded, the tensor is narrowed along
        `input_dim` to the slice owned by the current tensor-parallel rank.
        The result must match `param.data` in shape and is copied in place.

        Args:
            param: destination parameter. Optional attributes read from it:
                `input_dim`, `use_bitsandbytes_4bit`, `is_sharded_weight`,
                `is_gguf_weight`, `is_gguf_weight_type`.
            loaded_weight: tensor read from the checkpoint.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness) so that input_dim == 0 is also
            # divided across ranks, matching the sharding test below.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            if not envs.VLLM_USE_NN or is_quantization:
                shard_size = param_data.shape[input_dim]
            else:
                # With VLLM_USE_NN the unquantized parameter is stored
                # transposed relative to the checkpoint (see the `.t()`
                # below), so the shard length sits on the other axis.
                # NOTE(review): assumes a 2-D weight with input_dim in
                # {0, 1} — confirm.
                shard_size = param_data.shape[int(not input_dim)]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load `loaded_weight` by delegating to the parameter itself.

        Args:
            param: destination parameter; its `load_row_parallel_weight`
                method performs the actual load.
            loaded_weight: tensor read from the checkpoint. A 0-dim tensor is
                reshaped to shape `(1,)` first.
        """

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        use_fused_silu_mul_quant: Optional[bool] = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel linear transformation.

        Args:
            input_: input tensor. If `input_is_parallel` is false it is split
                along its last dimension and only this rank's slice is used.
            use_fused_silu_mul_quant: if truthy, run the fused
                SiLU-mul-quant kernel on the input and forward its two
                outputs to the quant method as `silu_quant_args`.

        Returns:
            `output` alone when `return_bias` is false, otherwise
            `(output, output_bias)`, where `output_bias` is `self.bias` if
            `skip_bias_add` is set and `None` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            # NOTE(review): `lm_fuse_silu_mul_quant` is not defined in this
            # class; it presumably comes from the optional import guarded by
            # `envs.USE_FUSED_SILU_MUL_QUANT` at the top of the file —
            # confirm it is available whenever callers pass True here.
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)

            silu_quant_args = [xq, xs]
            output_parallel = self.quant_method.apply(
                self,
                input_parallel,
                bias=bias_,
                silu_quant_args=silu_quant_args)
        else:
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            if envs.VLLM_ENABLE_TBO:
                output = self.tbo_all_reduce(output_parallel)
            else:
                output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1546
1547


1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Q is computed from the decoder hidden states by a `ColumnParallelLinear`
    and K/V from the encoder hidden states by a `QKVParallelLinear` built
    with `total_num_heads=0`. Both sub-layers are kept in a plain dict so
    their parameters are not auto-registered on this module.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. It is
                            passed unchanged to the inner `QKVParallelLinear`;
                            `None` counts as 0 when computing the placeholder
                            `output_size`.
                            NOTE(review): earlier docs said `None` means
                            total_num_kv_heads = total_num_heads — confirm
                            against `QKVParallelLinear`.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Empty placeholder parameter; it only carries loader attributes.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load hook on each projection layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _synced_proj(self, mode: Literal["q_proj_decoder",
                                         "kv_proj_encoder"]):
        """Return `self.proj[mode]` after syncing placeholder attributes.

        For every parameter registered on this module, if the projection
        layer has an attribute with the same name, attributes missing on the
        layer's parameter are copied over via `sync_weight_attrs`.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """The decoder-side Q projection, with weight attributes synced."""
        return self._synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """The encoder-side KV projection, with weight attributes synced."""
        return self._synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on `src_param` but absent on `tgt_param`.

        Args:
            src_param: placeholder parameter owned by this module.
            tgt_param: matching parameter of the projection layer.
            mode: which projection `tgt_param` belongs to; selects which half
                of the adjusted attributes is applied in the bitsandbytes
                4-bit case.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if (missing_attrs_dict and use_bitsandbytes_4bit):
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Exactly one parameter of `layer` must match according to
        `_is_same_param`; this is enforced with an assertion.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to Q and, if given, encoder states to K/V.

        Args:
            decoder_hidden_states: input to the Q projection.
            encoder_hidden_states: input to the KV projection, or `None`
                when the encoder KV is already cached.

        Returns:
            `(q, k, v)`; `k` and `v` are `None` when
            `encoder_hidden_states` is `None`.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a checkpoint tensor to the matching projection layer.

        Shard id "q" goes to the Q projection and is loaded without a shard
        id argument; any other value (including `None`) goes to the KV
        projection and is forwarded as the shard id.

        Args:
            param: placeholder parameter owned by this module.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: shard identifier supplied by the model loader.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s