linear.py 77 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           ModelWeightParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import GiB_bytes
gaoqiong's avatar
gaoqiong committed
34

zhuwenwen's avatar
zhuwenwen committed
35
import os
36
from vllm.model_executor.utils import gemm_bank_conf
37

38
39
40
41
42
# Optional fused RMS-norm + quantization op, imported only when the
# USE_FUSED_RMS_QUANT flag is set. The linear layers' forward() calls it
# with an rms weight, epsilon, residual and an int8 quant dtype.
# NOTE: an import failure is only printed, not re-raised, so
# `lm_faster_rmsquant` stays undefined and the first fused forward() call
# would then fail with a NameError.
if envs.USE_FUSED_RMS_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: Import fused rmsquant error: {e}")
43
44
45
46
47
48
49
        
# Optional fused op imported only when the USE_FUSED_SILU_MUL_QUANT flag is
# set (presumably a fused SiLU-mul + quantization kernel, going by its name;
# it is not called in the part of the file visible here).
# NOTE: an import failure is only printed, not re-raised, so the name stays
# undefined in that case.
if envs.USE_FUSED_SILU_MUL_QUANT:
    try:
        # from lightop import fuse_silu_mul_quant
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: Import fused silu_mul_qunat error: {e}")
50

51
52
# Module-level logger.
logger = init_logger(__name__)

# Class names of the quantization methods that are loaded through
# `weight_loader_v2`. ColumnParallelLinear compares
# `self.quant_method.__class__.__name__` against this list and falls back to
# the legacy `weight_loader` for any method that is not listed.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
77

78

79
80
81
82
83
84
85
86
87
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for a BitBLAS-tiled parameter.

    Args:
        param: Parameter that may carry a ``bitblas_tile_size`` attribute.
        shard_size: Shard size before the adjustment.
        shard_offset: Shard offset before the adjustment.

    Returns:
        ``(shard_size, shard_offset)``, each floor-divided by the tile size
        when the attribute is present and not ``None``; otherwise the inputs
        unchanged.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        # Not a BitBLAS-tiled parameter: nothing to rescale.
        return shard_size, shard_offset

    scaled_size = shard_size // tile
    scaled_offset = shard_offset // tile
    return scaled_size, scaled_offset


88
89
90
91
92
93
94
95
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for a Marlin-tiled parameter.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size before the adjustment.
        shard_offset: Shard offset before the adjustment.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by the tile size when
        the attribute is present and not ``None``; otherwise the inputs
        unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


96
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Map a shard's offset and size onto the quantized parameter's rows.

    Args:
        param: Quantized parameter; only ``param.data.shape[0]`` is read.
        shard_offsets: Mapping from shard id to an ``(offset, size)`` pair.
            The ``"total"`` entry supplies, as its first element, the extent
            that the other entries are expressed against.
        loaded_shard_id: Key in ``shard_offsets`` of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``. Note the order: size first.
    """
    full_extent, _ = shard_offsets["total"]
    shard_start, shard_len = shard_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Multiply before the floor division, so the ratio
        # packed_rows / full_extent is never truncated on its own.
        return value * packed_rows // full_extent

    packed_start = _rescale(shard_start)
    packed_len = _rescale(shard_len)
    return packed_len, packed_start


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused scale array that one on-disk scale fills.

    Fused modules (QKV and MLP) keep an array of length N with one scale per
    "logical" matrix, whereas each checkpoint shard carries a single scale.
    This selects the slot of ``param`` addressed by ``shard_id`` and returns
    it with the (possibly unwrapped) loaded scale.

    Args:
        param: Fused scale array, indexed by the resolved shard index.
        loaded_weight: Scale read from disk, either shapeless or with a
            leading dimension of 1.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an int
            index used as-is.

    Returns:
        ``(param[shard_id], loaded_weight)``. The loaded weight has ``.t()``
        applied when ``envs.VLLM_USE_NN`` is set.

    Raises:
        ValueError: If ``shard_id`` is neither a str nor an int.
    """
    if isinstance(shard_id, str):
        shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales have no shape; compressed-tensors scales do, and are
    # unwrapped here to their single element.
    if len(loaded_weight.shape) > 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    use_nn_layout = envs.VLLM_USE_NN
    target_slot = param[shard_id]
    source = loaded_weight.t() if use_nn_layout else loaded_weight
    return target_slot, source
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """Split BitsAndBytes 4-bit shard attributes into first shard and rest.

    The first shard keeps its own offsets and quant state. The remaining
    shards are re-based so their offsets start at zero and their quant-state
    keys start at 0 again.

    Example: given

        {
            'bnb_shard_offsets': array([0, 4, 8, 16]),
            'bnb_quant_state': {0: ..., 1: ..., 2: ...},
        }

    the result is the pair

        {
            'bnb_shard_offsets': array([0, 4]),
            'bnb_quant_state': {0: ...},
        }

    and

        {
            'bnb_shard_offsets': array([0, 4, 12]),
            'bnb_quant_state': {0: ..., 1: ...},
        }

    Args:
        bnb_weight_attrs: Mapping with ``'bnb_shard_offsets'`` (must support
            slicing and element-wise subtraction, as in the array example
            above) and ``'bnb_quant_state'`` (indexable by shard number).

    Returns:
        ``(left, right)``: two dicts with the same two keys as the input.
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Re-base the remaining boundaries onto the end of the first shard.
    tail_offsets = offsets[1:] - offsets[1]

    head_state = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    tail_state = {}
    for shard_idx in range(1, len(offsets) - 1):
        tail_state[shard_idx - 1] = bnb_weight_attrs["bnb_quant_state"][
            shard_idx]

    left = {"bnb_shard_offsets": head_offsets, "bnb_quant_state": head_state}
    right = {"bnb_shard_offsets": tail_offsets, "bnb_quant_state": tail_state}
    return left, right


174
class LinearMethodBase(QuantizeMethodBase):
    """Abstract interface for (possibly quantized) linear methods.

    A linear method creates a layer's weights and later applies them to an
    input tensor. Both hooks must be overridden by subclasses.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights of a linear layer and attach them to `layer`.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list holding
                the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights; their handling is up to the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the weights stored in `layer` to the input tensor `x`.

        Expects `create_weights` to have been called on `layer` beforehand.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two environment switches are sampled once, at construction time:

    * ``LLAMA_NN == '1'``: `apply` multiplies with ``layer.weight`` directly
      (``x @ weight``) instead of going through the dispatched GEMM.
    * ``GEMM_PAD == '1'``: in the ``LLAMA_NN`` path, allows the last 32
      columns of the weight to be sliced off when `gemm_bank_conf` says so.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the unquantized weight and register it on `layer`.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Input dimension of the local weight.
            output_partition_sizes: Output widths of the logical weights held
                locally; their sum is the local output dimension.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: dtype of the allocated weight.
            **extra_weight_attrs: Must contain ``weight_loader``; everything
                else is attached to the weight via `set_weight_attrs`.

        Raises:
            RuntimeError: If the allocation fails with a CUDA out-of-memory
                error (chained to the original exception).
        """
        # This method creates unquantized linear weights.
        # The weights are not quantized, and they are not sharded.
        # The amount of memory allocated for the weights is
        # sum(output_partition_sizes) * input_size_per_partition.
        try:
            weight_loader = extra_weight_attrs.pop("weight_loader")
            total_output_size = sum(output_partition_sizes)
            # With VLLM_USE_NN the weight is stored as (input, output), i.e.
            # transposed relative to the default (output, input) layout. The
            # input_dim/output_dim tags are the same in both cases; the
            # layers' weight loaders compensate for the transposed storage.
            if envs.VLLM_USE_NN:
                weight_shape = (input_size_per_partition, total_output_size)
            else:
                weight_shape = (total_output_size, input_size_per_partition)
            weight = ModelWeightParameter(data=torch.empty(
                weight_shape[0], weight_shape[1], dtype=params_dtype),
                                          input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader)
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            # Imported lazily so that non-CPU platforms never import it.
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply `x` with the layer's weight and optionally add `bias`.

        Args:
            layer: Module holding the ``weight`` created by `create_weights`.
            x: Input tensor.
            bias: Optional bias added to the product.

        Returns:
            The result of the matrix multiplication (plus bias, if given).
        """
        if self.use_llama_nn:
            # Use the flag cached in __init__ rather than indexing
            # os.environ directly: os.environ['GEMM_PAD'] raises KeyError
            # when the variable is unset.
            # NOTE(review): this re-binds layer.weight to a slice without its
            # last 32 columns; presumably the padding is added elsewhere and
            # the module accepts the re-assignment — confirm.
            if gemm_bank_conf(layer.weight.shape[1] -
                              32) and self.use_gemm_pad:
                layer.weight = layer.weight[:, :-32]

            if bias is None:
                return torch.matmul(x, layer.weight)
            if len(x.shape) == 2:
                # 2-D input: fused bias + matmul.
                return torch.addmm(bias, x, layer.weight)
            return torch.matmul(x, layer.weight) + bias

        # With the (input, output) storage layout the weight is transposed
        # back before being handed to the dispatched GEMM.
        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)
284

285

286
class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Stores the constructor arguments, resolves the quantization method and
    records the tensor-parallel rank and world size used by the layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to torch's
            current default dtype.
        quant_config: Quantization configure. When omitted, an
            UnquantizedLinearMethod is used.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward
            pass.
        disable_tp: If true, tensor parallelism will be disabled for this
            layer (rank 0, world size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        if quant_config is not None:
            self.quant_method: Optional[
                QuantizeMethodBase] = quant_config.get_quant_method(
                    self, prefix=prefix)
        else:
            self.quant_method = UnquantizedLinearMethod()

        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if disable_tp:
            # A layer with TP disabled behaves as a single-rank group.
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's TP rank and size onto its vLLM parameters.

        Only parameters that are instances of BasevLLMParameter are touched.
        """
        for candidate in self.parameters():
            if not isinstance(candidate, BasevLLMParameter):
                continue
            candidate.tp_rank = self.tp_rank
            candidate.tp_size = self.tp_size
341
342


343
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to the fused RMS-norm/quant op in `forward`.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Take no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # (A subclass may have set `output_sizes` before calling this
        # constructor.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into `param` without any sharding.

        Args:
            param: Destination parameter (weight, bias or a quantization
                parameter).
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the (possibly reshaped or transposed) tensor
                does not match the size of `param`.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Unquantized weights are stored transposed when VLLM_USE_NN is set,
        # so the checkpoint tensor is transposed to match.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        # The two message fragments are joined by a space; without it the
        # loaded size runs straight into "to a parameter".
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally through the fused RMS-quant path.

        The fused path is taken when ``envs.USE_FUSED_RMS_QUANT`` is set and
        either `rms_weight` or `quant_args` is given. In that path
        `quant_args`, if given, is passed to the quant method as-is;
        otherwise `lm_faster_rmsquant` is called and its two results are
        passed on.

        Args:
            input_: Input tensor.
            rms_weight: RMS-norm weight for the fused op.
            residual: Residual tensor handed to the fused op.
            quant_args: Pre-computed quantization arguments.
            update_hd: Forwarded to the fused op as ``update_input``.

        Returns:
            * ``output`` alone when ``return_bias`` is false;
            * ``(output, new_residual, output_bias, input_quant_args)`` when
              the fused op was called here. This is wider than the declared
              return annotation;
            * ``(output, output_bias)`` otherwise.

            ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set,
            else ``None``.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                                  or quant_args is not None)

        input_quant_args = None
        ran_fused_kernel = False
        if use_fused:
            if quant_args is not None:
                input_quant_args = quant_args
            else:
                i_q, _scales = lm_faster_rmsquant(input=input_,
                                                  rms_weight=rms_weight,
                                                  epsilon=self.eps,
                                                  quant_dtype=torch.int8,
                                                  residual=residual,
                                                  update_input=update_hd)
                input_quant_args = [i_q, _scales]
                ran_fused_kernel = True

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        if use_fused:
            output = self.quant_method.apply(self, input_, bias,
                                             input_quant_args)
        else:
            # The non-fused path keeps the three-argument call.
            output = self.quant_method.apply(self, input_, bias)

        if not self.return_bias:
            return output

        output_bias = self.bias if self.skip_bias_add else None
        if ran_fused_kernel:
            # The residual object passed in is handed back to the caller.
            new_residual = residual
            return output, new_residual, output_bias, input_quant_args
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a short summary of the layer's sizes and bias presence."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

492

493
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this constructor;
                       per-partition sizes come from `self.output_sizes` when
                       a subclass has set it.
        eps: Epsilon passed to the fused RMS-norm/quant op in `forward`.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The TP rank and size are needed to size the partitions before the
        # base constructor runs (which assigns the same values again).
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)

        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size)
                for logical_size in self.output_sizes
            ]
        else:
            self.output_partition_sizes = [self.output_size_per_partition]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED use the v2
        # loader; all others use the legacy one.
        method_name = self.quant_method.__class__.__name__
        if method_name in WEIGHT_LOADER_V2_SUPPORTED:
            chosen_loader = self.weight_loader_v2
        else:
            chosen_loader = self.weight_loader
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=chosen_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into `param`.

        Args:
            param: Destination parameter. Its optional ``output_dim``,
                ``is_sharded_weight``, ``use_bitsandbytes_4bit`` and GGUF
                attributes steer the loading.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the prepared tensor's shape differs from the
                parameter's shape.
        """
        output_dim = getattr(param, "output_dim", None)

        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        already_sharded = (getattr(param, "is_sharded_weight", False)
                           or getattr(param, "use_bitsandbytes_4bit", False))
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        if getattr(param, "is_gguf_weight_type", False):
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not already_sharded:
            # Unquantized 2-D weights are stored transposed under
            # VLLM_USE_NN, so the local output width sits on the other axis
            # of the parameter.
            stored_transposed = (envs.VLLM_USE_NN
                                 and len(param_data.shape) != 1
                                 and not is_quantization)
            if stored_transposed:
                shard_size = param_data.shape[int(not (output_dim))]
            else:
                shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to the parameter's column-parallel loader."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally through the fused RMS-quant path.

        The fused path is taken when ``envs.USE_FUSED_RMS_QUANT`` is set and
        `rms_weight` is given. `lm_faster_rmsquant` is then called and its
        two results are passed to the quant method as a fourth argument.

        Args:
            input_: Input tensor.
            rms_weight: RMS-norm weight for the fused op.
            residual: Residual tensor handed to the fused op.
            update_hd: Forwarded to the fused op as ``update_input``.

        Returns:
            * ``output`` alone when ``return_bias`` is false;
            * ``(output, new_residual, output_bias)`` on the fused path. This
              is wider than the declared return annotation;
            * ``(output, output_bias)`` otherwise.

            ``output`` is all-gathered across the TP group when
            ``gather_output`` is set and ``tp_size > 1``.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None

        if use_fused:
            i_q, _scales = lm_faster_rmsquant(input=input_,
                                              rms_weight=rms_weight,
                                              epsilon=self.eps,
                                              quant_dtype=torch.int8,
                                              residual=residual,
                                              update_input=update_hd)
            input_quant_args = [i_q, _scales]

        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        if use_fused:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output

        output_bias = self.bias if self.skip_bias_add else None
        if use_fused:
            # The residual object passed in is handed back to the caller.
            new_residual = residual
            return output, new_residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a short summary of sizes, bias presence and TP settings."""
        parts = [
            f"in_features={self.input_size}",
            f"output_features={self.output_size_per_partition}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"gather_output={self.gather_output}",
        ]
        return ", ".join(parts)

699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrices won't be sharded; this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the merged column-parallel linear layer.

        Two paths exist:

        * Fused RMS-norm + quantization, taken when
          ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is given.
          ``lm_faster_rmsquant`` produces the int8 input and its scales,
          which are handed to the quant method as a fourth argument.
        * Plain path: the quant method is applied to ``input_`` directly.

        Args:
            input_: Input activations.
            rms_weight: RMS-norm weight; its presence selects the fused path.
            residual: Residual tensor, required on the fused path.
            update_hd: Forwarded to ``lm_faster_rmsquant`` as
                ``update_input``.

        Returns:
            ``output`` alone when ``self.return_bias`` is false. Otherwise
            ``(output, residual, output_bias)`` on the fused path and
            ``(output, output_bias)`` on the plain path. ``output_bias`` is
            ``self.bias`` only when ``skip_bias_add`` is set, else ``None``.
        """
        use_fused_rms_quant = (envs.USE_FUSED_RMS_QUANT
                               and rms_weight is not None)

        # When the caller fuses the bias elsewhere, do not add it here.
        bias = self.bias if not self.skip_bias_add else None

        if use_fused_rms_quant:
            assert residual is not None
            i_q, _scales = lm_faster_rmsquant(input=input_,
                                              rms_weight=rms_weight,
                                              epsilon=self.eps,
                                              quant_dtype=torch.int8,
                                              residual=residual,
                                              update_input=update_hd)
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      [i_q, _scales])
        else:
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias)

        # Only gather when the layer is actually sharded. The fused path
        # previously gathered on `gather_output` alone, unlike the plain path
        # and ColumnParallelLinear.forward, which also check the TP size.
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output

        output_bias = self.bias if self.skip_bias_add else None
        if use_fused_rms_quant:
            # The same `residual` object is handed back to the caller.
            # NOTE(review): presumably updated in place by the fused
            # kernel - confirm against lm_faster_rmsquant.
            return output, residual, output_bias
        return output, output_bias

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a column-parallel layer whose output packs several matrices.

        Args:
            input_size: Input dimension of the linear layer.
            output_sizes: Output dimension of each packed sub-matrix; each
                must be divisible by the tensor-parallel size.
            bias: If true, add bias.
            gather_output: If true, all-gather the output in forward.
            skip_bias_add: If true, return the bias instead of adding it.
            params_dtype: Data type for the parameters.
            quant_config: Quantization configuration.
            eps: Stored as ``self.eps``; used as the epsilon of the fused
                RMS-norm/quant call in ``forward``.
            prefix: Name of the layer in the state dict.
            return_bias: If true, forward returns the bias with the output.
            disable_tp: If true, the layer is treated as unsharded
                (tp_size 1, tp_rank 0).
        """
        # These attributes are set before the parent constructor runs.
        # NOTE(review): `eps` is not forwarded to the parent constructor.
        self.eps = eps
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes), (
            f"every entry of output_sizes={output_sizes} must be divisible "
            f"by tp_size={self.tp_size}")

        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy a checkpoint tensor into this layer's packed parameter.

        Args:
            param: Destination parameter. Behaviour is steered by optional
                attributes read with ``getattr`` (``output_dim``,
                ``packed_dim``, ``is_gguf_weight``, ``use_bitsandbytes_4bit``,
                ``needs_scalar_to_array``, ...).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded, or ``None`` when the checkpoint
                stores the sub-matrices already fused. In that case the
                tensor is split and this method is called once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # (shard_id, offset, size) of every sub-matrix in the fused
            # checkpoint tensor, in unsharded units.
            current_shard_offset = 0
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # With VLLM_USE_NN the unquantized weight is narrowed along the
            # *other* axis (presumably because it is stored transposed - see
            # the `.t()` below). 1-D params such as the bias have no second
            # axis, so they must always be narrowed along `output_dim`;
            # QKVParallelLinear.weight_loader applies the same guard.
            if (not envs.VLLM_USE_NN or param_data.dim() == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Offset of the current sub-matrix inside the fused tensor, in
        # unpacked units; it advances by each entry of `output_sizes`.
        running_offset = 0
        for shard_id, output_size in enumerate(self.output_sizes):
            shard_offset, shard_size = running_offset, output_size
            running_offset += output_size

            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor through the parameter's own loader.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None``
                when the checkpoint stores the sub-matrices fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact type match on purpose: subclasses of these parameter
            # types fall through to the fused-checkpoint splitter.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        unsharded_offset = sum(self.output_sizes[:loaded_shard_id])
        unsharded_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Block scales: convert to block counts (rounded up) first, then
            # split across tensor-parallel ranks.
            shard_offset = (
                (unsharded_offset + block_n - 1) // block_n) // self.tp_size
            shard_size = (
                (unsharded_size + block_n - 1) // block_n) // self.tp_size
        else:
            shard_offset = unsharded_offset // self.tp_size
            shard_size = unsharded_size // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across tp ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build the fused QKV projection.

        Args:
            hidden_size: Input hidden state size.
            head_size: Size of each attention head.
            total_num_heads: Total number of query heads.
            total_num_kv_heads: Total number of key/value heads; defaults to
                ``total_num_heads`` when ``None``.
            bias: If true, add bias.
            skip_bias_add: If true, return the bias instead of adding it.
            params_dtype: Data type for the parameters.
            quant_config: Quantization configuration.
            prefix: Name of the layer in the state dict.
            return_bias: If true, forward returns the bias with the output.
            disable_tp: If true, a tensor-parallel size of 1 is used.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads

        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: every rank holds one KV
            # head and each head is replicated across several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Widths of the full (unsharded) q and k/v projections.
        q_total = self.total_num_heads * self.head_size
        kv_total = self.total_num_kv_heads * self.head_size
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, q_total),
            ("k", q_total, kv_total),
            ("v", q_total + kv_total, kv_total),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor through the parameter's own loader.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of "q", "k", "v", or ``None`` when the
                checkpoint stores the three projections fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            # Exact type match on purpose: subclasses of these parameter
            # types fall through to the fused-checkpoint splitter.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Convert element offsets/sizes to block counts, rounding up.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)
    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a checkpoint tensor into this layer's fused QKV parameter.

        Args:
            param: Destination parameter. Behaviour is steered by optional
                attributes read with ``getattr`` (``output_dim``,
                ``packed_dim``, ``is_gguf_weight``, ``use_bitsandbytes_4bit``,
                ``needs_scalar_to_array``, ...).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of "q", "k", "v", or ``None`` when the
                checkpoint stores the projections already fused. In that
                case the tensor is split and this method is called once
                per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Widths of the full (unsharded) q and k/v projections.
            q_total = self.total_num_heads * self.head_size
            kv_total = self.total_num_kv_heads * self.head_size
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, q_total),
                ("k", q_total, kv_total),
                ("v", q_total + kv_total, kv_total),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    # Unsharded (offset, size) per projection plus the total
                    # width, as expected by the bitsandbytes helper.
                    orig_qkv_offsets = {
                        sid: (offset, size)
                        for sid, offset, size in shard_offsets
                    }
                    orig_qkv_offsets["total"] = (q_total + 2 * kv_total, 0)

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-partition offset/size, shared with weight_loader_v2.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # Per-partition (offset, size) per projection plus the
                # per-partition total width.
                orig_qkv_offsets = {
                    sid: (self._get_shard_offset_mapping(sid),
                          self._get_shard_size_mapping(sid))
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            # With VLLM_USE_NN the unquantized 2-D weight is narrowed along
            # the *other* axis (presumably because it is stored transposed -
            # see the `.t()` below); 1-D params keep `output_dim`.
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            # K/V shards are shared by `num_kv_head_replicas` ranks, so the
            # source slice index is the rank divided by the replica count.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
1387
class RowParallelLinear(LinearBase):
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
1410
1411
1412
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
1413
        quant_config: Quantization configuration.
1414
1415
1416
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
1417
        disable_tp: If true, weights matrix won't be sharded through tp rank.
1418
1419
    """

1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up TP bookkeeping, create the weights and the optional bias.

        The arguments are described in the class docstring.

        Raises:
            ValueError: if ``reduce_results`` is False while a bias would be
                added to the (un-reduced) output, i.e. ``bias`` is True and
                ``skip_bias_add`` is False.
        """
        # Validate the argument combination first so that an invalid
        # configuration is rejected before any weight is allocated.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        # Divide the weight matrix along the first dimension.
        # With disable_tp the layer behaves as a single-rank group.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # Only the input dimension is split; the output stays whole.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # v2 loader; every other method uses the legacy loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))

        if bias:
            # The bias covers the full output size and always goes through
            # the legacy weight loader.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        self.update_param_tp_status()
1482
1483
1484

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The slice is taken along the parameter's ``input_dim`` attribute
        unless the parameter is flagged as already sharded (or uses
        bitsandbytes 4-bit). When ``envs.VLLM_USE_NN`` is set and the layer
        is unquantized, the loaded tensor is transposed before the copy, so
        the shard size is read from the opposite axis of ``param``.

        Args:
            param: destination parameter; optional attributes ``input_dim``,
                ``is_sharded_weight``, ``use_bitsandbytes_4bit``,
                ``is_gguf_weight`` and ``is_gguf_weight_type`` steer loading.
            loaded_weight: tensor read from the checkpoint.

        Raises:
            AssertionError: if the prepared tensor's shape differs from the
                parameter's shape.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness) so that input_dim == 0 is
            # sharded too, matching the narrowing condition below.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        # True when the loaded tensor must be transposed before the copy.
        transpose_on_load = bool(envs.VLLM_USE_NN and not is_quantization)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # In the transposed case the input axis of the checkpoint tensor
            # corresponds to the other axis of the parameter.
            size_axis = int(not input_dim) if transpose_on_load else input_dim
            shard_size = param_data.shape[size_axis]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if transpose_on_load:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

1528
1529
    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's own row-parallel loader.

        Args:
            param: destination parameter; its ``load_row_parallel_weight``
                method performs the actual load.
            loaded_weight: tensor read from the checkpoint.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1539
    def forward(
        self,
        input_,
        use_fused_silu_mul_quant: Optional[bool] = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel linear transform to ``input_``.

        Args:
            input_: input tensor. When ``self.input_is_parallel`` is False
                it is split along its last dimension and only this rank's
                slice is used.
            use_fused_silu_mul_quant: if truthy, ``input_`` is first passed
                through ``lm_fuse_silu_mul_quant`` and the two results are
                forwarded to the quant method as ``silu_quant_args``.

        Returns:
            The output tensor alone when ``self.return_bias`` is False,
            otherwise ``(output, output_bias)`` where ``output_bias`` is the
            layer bias if ``self.skip_bias_add`` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
            output_parallel = self.quant_method.apply(
                self, input_parallel, bias_, silu_quant_args=[xq, xs])
        else:
            output_parallel = self.quant_method.apply(self, input_parallel,
                                                      bias_)

        # The all-reduce is skipped for a single rank.
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias
1577
1578

    def extra_repr(self) -> str:
        """Return a one-line summary of the layer's sizes and TP settings."""
        return (f"in_features={self.input_size_per_partition}"
                f", output_features={self.output_size}"
                f", bias={self.bias is not None}"
                f", tp_size={self.tp_size}"
                f", reduce_results={self.reduce_results}")
1585
1586


1587
@CustomOp.register("qkv_cross_parallel_linear")
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create placeholder weights plus the two real projection layers.

        The module itself only holds zero-sized placeholder weights; the
        actual projections are a ``ColumnParallelLinear`` (decoder Q) and a
        ``QKVParallelLinear`` built with ``total_num_heads=0`` (encoder KV),
        both stored in the plain dict ``self.proj``.

        The arguments are described in the class docstring.
        """
        # input_size and output_size are not used, just for alignment
        # NOTE(review): a None total_num_kv_heads counts as 0 here, while the
        # class docstring says it defaults to total_num_heads - confirm.
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Placeholder bias; it is routed through weight_loader_v1.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None
1674

1675
1676
1677
1678
1679
    def process_weights_after_loading(self):
        """Run the quant method's post-load hook on every projection layer.

        Does nothing when ``self.quant_method`` is None.
        """
        proj_layers = self.proj.values()
        quant_method = self.quant_method
        if quant_method is None:
            return
        for proj_layer in proj_layers:
            quant_method.process_weights_after_loading(proj_layer)

1680
    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Return the decoder Q projection after syncing weight attributes.

        For every parameter of this module that has a same-named attribute
        on the projection layer, attributes missing on the layer's parameter
        are copied over via ``sync_weight_attrs``. This runs on each access.
        """
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="q_proj_decoder")
        return layer
1690
1691

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Return the encoder KV projection after syncing weight attributes.

        For every parameter of this module that has a same-named attribute
        on the projection layer, attributes missing on the layer's parameter
        are copied over via ``sync_weight_attrs``. This runs on each access.
        """
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but absent on ``tgt_param``.

        Args:
            src_param: parameter whose instance attributes are the source.
            tgt_param: parameter that receives the missing attributes.
            mode: which projection ``tgt_param`` belongs to; only consulted
                for bitsandbytes 4-bit parameters, where the missing
                attributes are first split into a Q part and a KV part.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if (missing_attrs_dict and use_bitsandbytes_4bit):
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)
1725

1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Find the parameter of ``layer`` that matches the placeholder ``param``.

        Matching is decided by ``_is_same_param``; exactly one parameter of
        ``layer`` must match, otherwise an AssertionError is raised.
        """
        matches = []
        for _, candidate in layer.named_parameters():
            if self._is_same_param(param, candidate):
                matches.append(candidate)
        assert len(matches) == 1
        return matches[0]

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to Q and, if given, encoder states to K/V.

        Args:
            decoder_hidden_states: input to the decoder Q projection.
            encoder_hidden_states: input to the encoder KV projection, or
                None when the encoder KV is already cached.

        Returns:
            ``(q, k, v)``; ``k`` and ``v`` are None when
            ``encoder_hidden_states`` is None.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` via the legacy loader of the target layer.

        Shard id "q" targets the decoder Q projection and is loaded without
        a shard id; any other value targets the encoder KV projection and is
        passed on as the shard id.
        """
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        if loaded_shard_id == "q":
            proj_layer = self.q_proj_decoder
            extra_args = ()
        else:
            proj_layer = self.kv_proj_encoder
            extra_args = (loaded_shard_id, )
        resolved_param = self.select_proj_params(proj_layer, param)
        proj_layer.weight_loader(resolved_param, loaded_weight, *extra_args)

1792
1793
1794
1795
1796
1797
1798
1799
    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route ``loaded_weight`` to the matching projection layer's loader.

        Shard id "q" targets the decoder Q projection and is loaded without
        a shard id; any other value targets the encoder KV projection and is
        passed on as the shard id. The layer's ``weight_loader_v2`` is used
        when this module's quant method is listed in
        ``WEIGHT_LOADER_V2_SUPPORTED``, otherwise its ``weight_loader``.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)
1804
1805
1806

    def extra_repr(self) -> str:
        """Return a one-line summary of the layer's sizes and TP world size."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s