linear.py 76 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
32

zhuwenwen's avatar
zhuwenwen committed
33
import os
34
from vllm.model_executor.utils import gemm_bank_conf
35

36
37
38
39
40
logger = init_logger(__name__)

# Optional fused kernels from lmslim. A failed import is reported but is not
# fatal; the corresponding name is then left undefined, so any forward path
# below that calls it (e.g. ``lm_faster_rmsquant``) will raise NameError.
# ``Exception`` (not just ImportError) is caught deliberately, matching the
# original best-effort behaviour for native-extension load failures.
if envs.USE_FUSED_RMS_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        logger.error("Import fused rmsquant error: %s", e)

if envs.USE_FUSED_SILU_MUL_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        logger.error("Import fused silu_mul_quant error: %s", e)

51
# Quant-method class names whose layers are loaded through
# ``weight_loader_v2`` (BasevLLMParameter-based loading) instead of the
# legacy ``weight_loader``; compared against ``quant_method.__class__.__name__``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
74

75

76
77
78
79
80
81
82
83
84
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    If ``param`` carries a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)``, possibly rescaled.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


85
86
87
88
89
90
91
92
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset up by the Marlin tile size.

    When ``param`` has a ``marlin_tile_size`` attribute both values are
    multiplied by it; otherwise they pass through untouched.

    Returns:
        ``(shard_size, shard_offset)``, possibly rescaled.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


93
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Rescales the ``(offset, size)`` of ``loaded_shard_id`` from the original
    coordinate space (whose extent is the first element of
    ``shard_offsets["total"]``) into the packed space whose extent is
    ``param.data.shape[0]``. Both results use integer floor division.

    Args:
        param: The packed parameter being loaded into.
        shard_offsets: Maps each shard id to ``(offset, size)``; the
            ``"total"`` entry's first element is the full original extent.
        loaded_shard_id: Key of the shard currently being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note that size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we index the param based on
    the shard_id for loading.

    Args:
        param: Fused per-shard scale array, indexed by shard position.
        loaded_weight: Scale for one shard as read from the checkpoint;
            either shape-less or with a leading dimension of size 1.
        shard_id: ``"q"``/``"k"``/``"v"`` (mapped to 0/1/2) or an int index.

    Returns:
        ``(param[shard_id], loaded_weight)``; when ``envs.VLLM_USE_NN`` is
        set the returned ``loaded_weight`` has ``.t()`` applied.

    Raises:
        KeyError: ``shard_id`` is a string other than q/k/v.
        ValueError: ``shard_id`` is neither a str nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    if envs.VLLM_USE_NN:
        return param[shard_id], loaded_weight.t()
    else:
        return param[shard_id], loaded_weight
132
133


134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard metadata into the first shard and the rest.

    The left result keeps only the first shard. The right result holds every
    remaining shard, with its offsets re-based to start at zero and its
    quant-state keys renumbered from zero.

    Example input:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    Returned pair:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Subtracting the first boundary makes the tail start at zero.
    tail_offsets = offsets[1:] - offsets[1]
    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    tail_states = {}
    for src in range(1, len(offsets) - 1):
        tail_states[src - 1] = states[src]
    return (
        {"bnb_shard_offsets": head_offsets, "bnb_quant_state": head_states},
        {"bnb_shard_offsets": tail_offsets, "bnb_quant_state": tail_states},
    )


171
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes (callers in this file pass
                ``weight_loader``) for implementations to attach to the
                created weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two environment switches are read once, at construction time:
    ``LLAMA_NN=1`` selects a plain ``x @ weight`` path in :meth:`apply`,
    and ``GEMM_PAD=1`` lets that path drop the last 32 columns of the
    weight when ``gemm_bank_conf`` says so.
    """

    def __init__(self):
        # Snapshot the switches once; ``apply`` must not index os.environ
        # directly (a missing GEMM_PAD would raise KeyError).
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialised ``weight`` parameter on ``layer``.

        The layout is ``(input, output)`` when ``envs.VLLM_USE_NN`` is set
        and ``(output, input)`` otherwise. In both cases the attributes
        ``input_dim=1`` / ``output_dim=0`` are attached, plus every entry of
        ``extra_weight_attrs``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            shape = (input_size_per_partition, output_size_per_partition)
        else:
            shape = (output_size_per_partition, input_size_per_partition)
        weight = Parameter(torch.empty(*shape, dtype=params_dtype),
                           requires_grad=False)
        # NOTE(review): output_dim stays 0 even for the transposed NN layout;
        # the weight loaders in this file narrow the checkpoint tensor along
        # output_dim *before* transposing it, so this appears intentional.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, hand the layer to ``dispatch_cpu_unquantized_gemm``."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the (optionally biased) linear projection of ``x``."""
        if self.use_llama_nn:
            # This path computes x @ weight directly, i.e. it expects the
            # weight to be laid out as (in, out).
            # Check the cached flag first so gemm_bank_conf is only consulted
            # when padding is enabled.
            if self.use_gemm_pad and gemm_bank_conf(layer.weight.shape[1] - 32):
                # NOTE(review): this trim runs on every call whose condition
                # holds, not once; and nn.Module rejects assigning a plain
                # tensor slice to a registered Parameter attribute. Confirm
                # the intended one-shot semantics upstream.
                layer.weight = layer.weight[:, :-32]

            if bias is None:
                return torch.matmul(x, layer.weight)
            if len(x.shape) == 2:
                # addmm fuses the bias add for the 2-D case.
                return torch.addmm(bias, x, layer.weight)
            return torch.matmul(x, layer.weight) + bias

        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            # NN layout stores (in, out); the dispatched GEMM receives the
            # transposed view.
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)
257

258

259
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (tp_rank is forced to 0 and tp_size to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto every BasevLLMParameter."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
314
315


316
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
             RMSNorm + quantization path (``envs.USE_FUSED_RMS_QUANT``).
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Take no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` without any sharding."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # With VLLM_USE_NN the unquantized weight is created as
        # (input, output), so the checkpoint tensor is transposed to match.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally via the fused RMSNorm+int8 path.

        The fused path is taken only when ``envs.USE_FUSED_RMS_QUANT`` is
        set and ``rms_weight`` or ``quant_args`` is given:

        * ``quant_args`` given: forwarded unchanged as the extra argument to
          ``quant_method.apply`` (presumably the ``[i_q, scales]`` returned
          by an earlier fused call -- confirm with callers).
        * otherwise ``lm_faster_rmsquant`` is called on ``input_`` with
          ``rms_weight``/``residual``/``update_hd`` and its ``(i_q, scales)``
          result is forwarded as ``[i_q, scales]``.

        Returns:
            ``output`` if ``return_bias`` is False. Otherwise
            ``(output, output_bias)`` on the plain and ``quant_args`` paths,
            and ``(output, residual, output_bias, [i_q, scales])`` on the
            RMS-quant path, where ``residual`` is the object passed in
            (presumably updated in place by the kernel -- confirm).
        """
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None
        assert self.quant_method is not None

        use_fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                                  or quant_args is not None)
        if not use_fused:
            output = self.quant_method.apply(self, input_, bias)
            if not self.return_bias:
                return output
            return output, output_bias

        if quant_args is not None:
            output = self.quant_method.apply(self, input_, bias, quant_args)
            if not self.return_bias:
                return output
            return output, output_bias

        i_q, scales = lm_faster_rmsquant(input=input_,
                                         rms_weight=rms_weight,
                                         epsilon=self.eps,
                                         quant_dtype=torch.int8,
                                         residual=residual,
                                         update_input=update_hd)
        input_quant_args = [i_q, scales]
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        if not self.return_bias:
            return output
        return output, residual, output_bias, input_quant_args

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

457

458
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: Accepted for API compatibility; not read by this
                       constructor. Per-partition sizes are instead derived
                       from ``self.output_sizes`` when a subclass sets it
                       before calling ``__init__``.
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
             RMSNorm + quantization path (``envs.USE_FUSED_RMS_QUANT``).
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # Use the resolved dtype (LinearBase maps None to the default
            # dtype), consistent with the weights and ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``."""
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                shard_size = param_data.shape[output_dim]
            else:
                # With VLLM_USE_NN the unquantized 2-D weight is stored
                # transposed, so the per-rank output width is on the other
                # axis of param_data.
                shard_size = param_data.shape[int(not output_dim)]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally via the fused RMSNorm+int8 path.

        When ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is given,
        ``lm_faster_rmsquant`` is called on ``input_`` and its
        ``(i_q, scales)`` result is passed to ``quant_method.apply`` as an
        extra ``[i_q, scales]`` argument.

        Returns:
            ``output`` if ``return_bias`` is False. Otherwise
            ``(output, residual, output_bias)`` on the fused path (``residual``
            is the object passed in) and ``(output, output_bias)`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None
        if fused:
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      [i_q, scales])
        else:
            # Matrix multiply.
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        if fused:
            return output, residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to the fused RMS-norm/quantization kernel in
             ``forward``; only used when ``envs.USE_FUSED_RMS_QUANT`` is set
             and ``rms_weight`` is given.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def forward(
        self,
        input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the merged column-parallel projection.

        When ``envs.USE_FUSED_RMS_QUANT`` is enabled and ``rms_weight`` is
        given, ``input_`` is first run through ``lm_faster_rmsquant`` (called
        with ``quant_dtype=torch.int8``, ``epsilon=self.eps`` and
        ``residual``). The resulting quantized activations and scales are
        handed to ``quant_method.apply`` as extra arguments.

        Args:
            input_: Input activations.
            rms_weight: Norm weight for the fused path; ``None`` selects the
                plain (unfused) path.
            residual: Residual tensor; required on the fused path.
            update_hd: Forwarded to the fused kernel as ``update_input``.

        Returns:
            ``output`` when ``return_bias`` is false. Otherwise
            ``(output, output_bias)`` on the plain path, and
            ``(output, residual, output_bias)`` on the fused path.
        """
        fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None
        input_quant_args = None
        new_residual = None
        if fused:
            assert residual is not None
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            # The same residual tensor object that was passed to the kernel
            # is returned to the caller.
            new_residual = residual
            input_quant_args = [i_q, scales]

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        if input_quant_args is None:
            output_parallel = self.quant_method.apply(self, input_, bias)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)

        # Only gather when the layer is actually sharded: with disable_tp
        # (tp_size == 1) each rank already holds the full output, and
        # gathering across the real TP group would duplicate it.
        if self.gather_output and self.tp_size > 1:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        if fused:
            return output, new_residual, output_bias
        return output, output_bias

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Only consumed by the fused RMS-quant branch of forward().
        self.eps = eps
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        # Every sub-projection must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load (a shard of) a checkpoint tensor into ``param``.

        ``loaded_shard_id`` is the index into ``self.output_sizes`` of the
        sub-projection ``loaded_weight`` belongs to. ``None`` means the
        checkpoint stores the projections already fused; the tensor is then
        split and this method recurses once per sub-projection.

        Under ``envs.VLLM_USE_NN`` with an unquantized layer, 2-D parameters
        are narrowed on the dimension other than ``output_dim``, and the
        loaded tensor is transposed before copying.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size

            # The bitsandbytes offset table does not depend on the shard
            # being processed, so build it once.
            orig_offsets = None
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # Under VLLM_USE_NN, unquantized 2-D weights are narrowed on the
            # dimension opposite to output_dim. 1-D params (e.g. bias) have
            # only one dimension; narrowing dim 1 would raise, so they use
            # output_dim directly (same rule as QKVParallelLinear).
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not (output_dim)),
                                               shard_offset, shard_size)

            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        # The checkpoint shard is transposed to match the narrowing above;
        # .t() is a no-op on 1-D tensors.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` into a vLLM parameter object.

        ``loaded_shard_id`` indexes ``self.output_sizes``; ``None`` means the
        checkpoint tensor is already fused and is split (or loaded as a whole
        for scale/row parameters) before delegating to the parameter's
        ``load_merged_column_weight``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are expressed in blocks of block_n output rows.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)
997

998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head, and
            # each KV head is replicated on tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Offset of the q/k/v shard (or "total" width) in this rank's
        partition, along the output dimension."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Size of the q/k/v shard in this rank's partition, along the
        output dimension."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v in the
        unsharded (checkpoint) fused QKV tensor, along the output dimension.
        """
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` into a vLLM parameter object.

        ``loaded_shard_id`` is one of "q", "k", "v"; ``None`` means the
        checkpoint tensor is already fused and is split (or loaded whole for
        scale/row parameters) before delegating to ``load_qkv_weight``.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            # Same guard as MergedColumnParallelLinear: subscripting None
            # would otherwise fail with an opaque TypeError.
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block-quant scales are indexed in blocks of block_n rows.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load (a shard of) a checkpoint tensor into ``param``.

        ``loaded_shard_id`` is "q", "k" or "v"; ``None`` means the checkpoint
        stores QKV already fused, in which case the tensor is split and this
        method recurses once per shard. K/V shards are selected with
        ``tp_rank // num_kv_head_replicas`` so replicated KV heads are shared
        between ranks.

        Under ``envs.VLLM_USE_NN`` with an unquantized layer, 2-D parameters
        are narrowed on the dimension other than ``output_dim`` and the loaded
        tensor is transposed before copying.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            # The bitsandbytes offset table is identical for every shard,
            # so build it once outside the loop.
            orig_qkv_offsets = None
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    sid: (offset, size)
                    for sid, offset, size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    (self.total_num_heads + 2 * self.total_num_kv_heads) *
                    self.head_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank offset/size of this shard inside the local partition.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    sid: (self._get_shard_offset_mapping(sid),
                          self._get_shard_size_mapping(sid))
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            # Under VLLM_USE_NN, unquantized 2-D weights are narrowed on the
            # dimension opposite to output_dim; 1-D params (e.g. bias) use
            # output_dim directly.
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not (output_dim)),
                                               shard_offset, shard_size)

            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                # Ranks that share a replicated KV head load the same slice.
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        # The checkpoint shard is transposed to match the narrowing above;
        # .t() is a no-op on 1-D tensors.
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1351
1352


1353
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Each rank holds ``input_size / tp_size`` rows of A. When
    ``envs.VLLM_USE_NN`` is set and the layer is unquantized, the legacy
    ``weight_loader`` stores the weight transposed relative to the
    checkpoint tensor: it reads the shard size from the other axis and
    transposes the loaded tensor before copying.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        # With disable_tp the layer behaves as a single, unsharded rank.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED load through
            # the vLLM parameter classes; everything else uses the legacy
            # loader below.
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        # All-reduce variant used by forward() when envs.VLLM_ENABLE_TBO is
        # set (two-batch overlap).
        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``input_dim`` attribute and is not already
        sharded, ``loaded_weight`` is narrowed along ``input_dim`` to the
        slice owned by ``self.tp_rank``. For an unquantized layer with
        ``envs.VLLM_USE_NN`` set, the parameter is stored transposed. The
        shard size is then read from the other axis of the (2-D) parameter,
        and the narrowed tensor is transposed before the copy.

        Args:
            param: Destination parameter (weight, bias or scale).
            loaded_weight: Tensor read from the checkpoint; either the full
                tensor or, for pre-sharded/bitsandbytes-4bit params, this
                rank's portion.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None (not truthiness) so that an input_dim of
            # 0 is also split across ranks. This keeps it consistent with the
            # narrowing below, which shards whenever input_dim is not None.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            if not envs.VLLM_USE_NN or is_quantization:
                shard_size = param_data.shape[input_dim]
            else:
                # Transposed storage: the checkpoint's input axis is the
                # other axis of the 2-D parameter.
                shard_size = param_data.shape[int(not (input_dim))]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Match the transposed parameter layout (a no-op for 1-D tensors).
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's own row-parallel
        loading logic (``param.load_row_parallel_weight``)."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        use_fused_silu_mul_quant: Optional[bool] = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        Args:
            input_: Layer input. If ``input_is_parallel`` is False, it is
                split along its last dimension and this rank's chunk is used.
            use_fused_silu_mul_quant: If true, ``lm_fuse_silu_mul_quant`` is
                applied to the input and its two outputs are forwarded to
                ``quant_method.apply`` as ``silu_quant_args``.
                NOTE(review): that name appears to be imported at module
                level only when ``envs.USE_FUSED_SILU_MUL_QUANT`` is set;
                confirm callers pass True only in that configuration.

        Returns:
            ``output`` if ``return_bias`` is False, else
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)

            silu_quant_args = [xq, xs]
            output_parallel = self.quant_method.apply(
                self,
                input_parallel,
                bias=bias_,
                silu_quant_args=silu_quant_args)
        else:
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            if envs.VLLM_ENABLE_TBO:
                output = self.tbo_all_reduce(output_parallel)
            else:
                output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1558
1559


1560
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Two sub-layers do the projections:

    * ``q_proj_decoder`` is a ``ColumnParallelLinear`` that projects queries
      from the decoder hidden states.
    * ``kv_proj_encoder`` is a ``QKVParallelLinear`` with zero query heads
      that projects keys and values from the encoder hidden states.

    Both sub-layers live in a plain dict, so their parameters are not
    registered on this module. This module instead owns zero-size
    placeholder parameters. Their weight loaders pick the sub-layer from the
    shard id (``"q"`` -> query layer, anything else -> KV layer) and forward
    the checkpoint tensor to the matching sub-layer parameter.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
                            NOTE(review): kv_proj_encoder is built with
                            total_num_heads=0 and receives this value as-is;
                            confirm the None default yields the intended
                            KV head count.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Placeholder bias; its loader routes to the sub-layer biases.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        for layer in self.proj.values():
            if self.quant_method is not None:
                self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Return the query sub-layer.

        Before returning, copy any attributes the placeholder parameters
        carry (and the same-named sub-layer parameters lack) onto the
        sub-layer's parameters.
        """
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Return the key/value sub-layer.

        Before returning, copy placeholder parameter attributes onto the
        sub-layer's parameters, as ``q_proj_decoder`` does.
        """
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes set on ``src_param`` but absent on ``tgt_param``.

        For bitsandbytes 4-bit parameters, the missing attributes are split by
        ``left_shift_bitsandbytes_4bit_shard`` into query and key/value parts,
        and only the part matching ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if (missing_attrs_dict and use_bitsandbytes_4bit):
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Exactly one parameter of ``layer`` must match (same type and same
        attributes apart from the weight loader); otherwise the assertion
        fails.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project queries from the decoder and keys/values from the encoder.

        Args:
            decoder_hidden_states: Input to the query projection.
            encoder_hidden_states: Input to the key/value projection, or None
                when no encoder KV needs to be computed (encoder KV already
                cached).

        Returns:
            ``(q, k, v)``. ``k`` and ``v`` are None when
            ``encoder_hidden_states`` is None.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Route a placeholder (bias) tensor to the sub-layer's legacy
        ``weight_loader``; the ``"q"`` shard goes to the query layer."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        # ColumnParallelLinear (q) is called without a shard id.
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a placeholder tensor to the matching sub-layer parameter.

        The sub-layer's ``weight_loader_v2`` is used when the quant method
        is in ``WEIGHT_LOADER_V2_SUPPORTED``; otherwise its ``weight_loader``
        is used.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        # ColumnParallelLinear (q) is called without a shard id.
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s