linear.py 78.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           ModelWeightParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import GiB_bytes
gaoqiong's avatar
gaoqiong committed
34

zhuwenwen's avatar
zhuwenwen committed
35
import os
36
from vllm.model_executor.utils import gemm_bank_conf
37

38
39
40
41
42
logger = init_logger(__name__)

# Optional fused kernels from ``lmslim``. These imports are best-effort: if
# ``lmslim`` is missing or broken, the failure is logged and the affected name
# stays undefined, so any later call to it (e.g. ``lm_faster_rmsquant`` in the
# fused forward paths below) raises NameError at that point.
if envs.USE_FUSED_RMS_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:  # deliberately broad: keep module import alive
        logger.error(
            "Failed to import fused rmsquant kernel "
            "(lmslim lm_faster_rmsquant): %s", e)

if envs.USE_FUSED_SILU_MUL_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:  # deliberately broad: keep module import alive
        logger.error(
            "Failed to import fused silu_mul_quant kernel "
            "(lmslim lm_fuse_silu_mul_quant): %s", e)

53
# Class names of quant methods for which ColumnParallelLinear passes
# ``weight_loader_v2`` (which delegates to the parameter's own
# ``load_column_parallel_weight``) instead of the legacy ``weight_loader``.
# Matching is done on ``quant_method.__class__.__name__``, so entries are
# plain strings and their order is irrelevant to lookups.
WEIGHT_LOADER_V2_SUPPORTED = [
    # unquantized / compressed-tensors
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # BitBLAS
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    # AWQ / GPTQ / Marlin / FP8 families
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    # IPEX
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    # misc
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
77

78

79
80
81
82
83
84
85
86
87
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Express a shard's size and offset in BitBLAS tile units.

    When ``param`` has a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the optional rescaling.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


88
89
90
91
92
93
94
95
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size, if any.

    When ``param`` has a non-None ``marlin_tile_size`` attribute, both values
    are multiplied by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the optional rescaling.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


96
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets[loaded_shard_id]`` is an ``(offset, size)`` pair in the
    original index space, whose extent is the first element of
    ``shard_offsets["total"]``. Both values are rescaled proportionally to
    ``param.data.shape[0]`` (the quantized extent), using floor division.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note: size comes first.
    """
    original_extent, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]

    quantized_extent = param.data.shape[0]
    return (size * quantized_extent // original_extent,
            offset * quantized_extent // original_extent)


111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-shard scale array that a scalar loads into.

    Fused modules (QKV and MLP) keep an array of length N holding one scale
    per "logical" matrix, while ``loaded_weight`` is the scale of a single
    on-disk shard.

    Args:
        param: Fused scale array, indexed by shard.
        loaded_weight: Scale read from disk; either 0-d or with a leading
            dimension of size 1.
        shard_id: ``"q"``/``"k"``/``"v"`` or an integer index.

    Returns:
        ``(param[shard_id], loaded_weight)``; ``loaded_weight`` is passed
        through ``.t()`` when ``envs.VLLM_USE_NN`` is set.

    Raises:
        KeyError: for a string ``shard_id`` other than q/k/v.
        ValueError: for a ``shard_id`` that is neither ``str`` nor ``int``.
    """
    if isinstance(shard_id, str):
        shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are shapeless; compressed-tensors scales carry a
    # leading dimension of size 1 that is unwrapped here.
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    transpose = envs.VLLM_USE_NN
    fused_slot = param[shard_id]
    return fused_slot, (loaded_weight.t() if transpose else loaded_weight)
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard metadata into the first shard and the rest.

    The first result keeps the first two boundaries and quant state 0. The
    second result re-bases the remaining boundaries to start at zero and
    renumbers the remaining quant states from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Shift the tail so that its first boundary becomes zero.
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    tail_states = {idx: states[idx + 1] for idx in range(len(offsets) - 2)}

    head = dict(bnb_shard_offsets=head_offsets, bnb_quant_state=head_states)
    tail = dict(bnb_shard_offsets=tail_offsets, bnb_quant_state=tail_states)
    return head, tail


174
class LinearMethodBase(QuantizeMethodBase):
    """Interface implemented by every (quantized or unquantized) linear method.

    Implementations allocate the layer's parameters in ``create_weights`` and
    perform the forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the weights of a linear layer and attach them to ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear it is a list containing
                the widths of Wq, Wk and Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes supplied by the caller
                (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer's weights on ``x`` and return the result.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    The weight layout depends on ``envs.VLLM_USE_NN``:

    * off: ``weight`` is allocated as ``(sum(output_partition_sizes),
      input_size_per_partition)`` and ``apply`` calls
      ``dispatch_unquantized_gemm()``;
    * on: ``weight`` is allocated as ``(input_size_per_partition,
      sum(output_partition_sizes))`` and ``apply`` computes ``x @ weight``
      directly with ``torch.matmul`` / ``torch.addmm``.

    Setting the environment variable ``LLAMA_NN=1`` also routes ``apply``
    through the ``x @ weight`` path.
    """

    def __init__(self):
        # Environment switches are read once, when the method is constructed.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # NOTE(review): ``use_gemm_pad`` is not read anywhere in this class;
        # kept in case other code inspects it -- confirm before removing.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an uninitialized, unquantized ``weight`` and register it.

        ``extra_weight_attrs`` must contain ``weight_loader``; it is popped
        and handed to the parameter, and the remaining entries are attached
        to the weight with ``set_weight_attrs``.

        Raises:
            RuntimeError: if CUDA runs out of memory during the allocation
                (the ``torch.cuda.OutOfMemoryError`` is chained as cause).
        """
        # Elements allocated: sum(output_partition_sizes) *
        # input_size_per_partition.
        try:
            weight_loader = extra_weight_attrs.pop("weight_loader")
            output_size_per_partition = sum(output_partition_sizes)
            if envs.VLLM_USE_NN:
                # NN layout: stored as (in, out) so ``apply`` can do x @ W.
                shape = (input_size_per_partition, output_size_per_partition)
            else:
                shape = (output_size_per_partition, input_size_per_partition)
            # NOTE(review): input_dim/output_dim are NOT swapped for the NN
            # layout; loaders must compensate (ColumnParallelLinear's legacy
            # weight_loader does) -- confirm for every other loader.
            weight = ModelWeightParameter(data=torch.empty(*shape,
                                                           dtype=params_dtype),
                                          input_dim=1,
                                          output_dim=0,
                                          weight_loader=weight_loader)
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by ``layer.weight`` and add ``bias`` if given.

        With ``LLAMA_NN=1`` or ``envs.VLLM_USE_NN`` the product is computed
        as ``x @ layer.weight``; otherwise ``dispatch_unquantized_gemm()``
        is used.
        """
        # NOTE(review): ``create_weights`` only switches to the (in, out)
        # layout under envs.VLLM_USE_NN; the LLAMA_NN path assumes the
        # weight is already (in, out) -- confirm how it gets there.
        if self.use_llama_nn or envs.VLLM_USE_NN:
            return self._matmul_nn(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

    @staticmethod
    def _matmul_nn(x: torch.Tensor, weight: torch.Tensor,
                   bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Compute ``x @ weight (+ bias)`` for a weight stored as (in, out)."""
        if bias is None:
            return torch.matmul(x, weight)
        if len(x.shape) == 2:
            # addmm fuses the bias add for 2-D inputs.
            return torch.addmm(bias, x, weight)
        return torch.matmul(x, weight) + bias
291

292

293
class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Stores the layer configuration, selects the quant method
    (``UnquantizedLinearMethod`` when no ``quant_config`` is given) and
    records this layer's tensor-parallel rank and size.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization configure.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (rank 0, world size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        # The quant method is chosen after the attributes above are set, since
        # ``get_quant_method`` receives this partially-built layer.
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        else:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()

        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        vllm_params = (p for p in self.parameters()
                       if isinstance(p, BasevLLMParameter))
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
348
349


350
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full weight is held on every rank: ``create_weights`` receives the
    full ``output_size`` and ``weight_loader`` copies checkpoint tensors
    without narrowing.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to the fused ``lm_faster_rmsquant`` kernel when
             ``forward`` is called with ``rms_weight``. NOTE: this argument
             sits positionally before ``prefix``; pass ``prefix`` by keyword.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Take no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        Handles GGUF type tags and lazy parameters, gives shapeless scale
        tensors a shape, and transposes unquantized tensors under
        ``envs.VLLM_USE_NN``.

        Raises:
            AssertionError: if the shapes do not match after those steps.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Unquantized weights are allocated as (in, out) under VLLM_USE_NN
        # (UnquantizedLinearMethod.create_weights); ``.t()`` leaves 1-D
        # tensors such as the bias unchanged.
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[Any, ...]]:
        """Apply ``self.quant_method`` to ``input_``.

        Paths (the return shape depends on which one is taken):

        * ``envs.USE_FUSED_RMS_QUANT`` falsy, or neither ``rms_weight`` nor
          ``quant_args`` given: plain ``quant_method.apply``. Returns
          ``output`` or ``(output, output_bias)``.
        * fused, ``quant_args`` given (takes priority over ``rms_weight``):
          ``quant_args`` are forwarded to ``quant_method.apply``. Returns
          ``output`` or ``(output, output_bias)``.
        * fused, only ``rms_weight`` given: ``lm_faster_rmsquant`` (int8)
          produces ``[i_q, scales]``, which are forwarded to
          ``quant_method.apply``. Returns ``output`` or
          ``(output, residual, output_bias, [i_q, scales])``.

        ``output`` alone is returned whenever ``self.return_bias`` is false.
        """
        fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                              or quant_args is not None)
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None
        assert self.quant_method is not None

        if not fused:
            output = self.quant_method.apply(self, input_, bias)
            if not self.return_bias:
                return output
            return output, output_bias

        if quant_args is not None:
            output = self.quant_method.apply(self, input_, bias, quant_args)
            if not self.return_bias:
                return output
            return output, output_bias

        i_q, scales = lm_faster_rmsquant(input=input_,
                                         rms_weight=rms_weight,
                                         epsilon=self.eps,
                                         quant_dtype=torch.int8,
                                         residual=residual,
                                         update_input=update_hd)
        input_quant_args = [i_q, scales]
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        if not self.return_bias:
            return output
        # NOTE(review): ``residual`` is returned unchanged by reference;
        # presumably the kernel updates it in place when ``update_input`` is
        # set -- confirm.
        return output, residual, output_bias, input_quant_args

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

499

500
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: Accepted for signature compatibility but not read here;
                       per-partition sizes come from a ``self.output_sizes``
                       attribute when one is already set before this
                       initializer runs (e.g. by QKV/merged subclasses).
        eps: Epsilon passed to the fused ``lm_faster_rmsquant`` kernel when
             ``forward`` is called with ``rms_weight``. NOTE: this argument
             sits positionally before ``prefix``; pass ``prefix`` by keyword.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # tp_rank/tp_size are needed before LinearBase.__init__ (which sets
        # them again) to size this rank's output partition.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Legacy loader: copy this rank's column shard of a checkpoint tensor.

        Narrows ``loaded_weight`` along ``param.output_dim`` (unless the
        weight is already sharded or is bitsandbytes 4-bit), handles GGUF
        tags/lazy parameters, gives shapeless scales a shape, and transposes
        unquantized tensors under ``envs.VLLM_USE_NN``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Under VLLM_USE_NN, unquantized 2-D weights are stored as
            # (in, out) while ``output_dim`` still indexes the checkpoint
            # layout, so the local shard size is read from the other axis.
            stored_transposed = (envs.VLLM_USE_NN
                                 and len(param_data.shape) != 1
                                 and not self.is_quantization)
            if stored_transposed:
                shard_size = param_data.shape[int(not output_dim)]
            else:
                shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Match the (in, out) storage; ``.t()`` leaves 1-D tensors unchanged.
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load via ``param.load_column_parallel_weight``.

        Used when the quant method's class name is listed in
        ``WEIGHT_LOADER_V2_SUPPORTED``.

        NOTE(review): with tp_size > 1 under VLLM_USE_NN (unquantized), the
        transposed tensor is handed to ``load_column_parallel_weight`` while
        the parameter's ``output_dim`` is still 0 -- confirm that it narrows
        along the correct axis.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[Any, ...]]:
        """Apply ``self.quant_method`` and optionally all-gather the result.

        When ``envs.USE_FUSED_RMS_QUANT`` is truthy and ``rms_weight`` is
        given, ``lm_faster_rmsquant`` (int8) first produces ``[i_q, scales]``,
        which are forwarded to ``quant_method.apply``; the return value is
        then ``output`` or ``(output, residual, output_bias)``. Otherwise a
        plain ``quant_method.apply`` is done and the return value is
        ``output`` or ``(output, output_bias)``. ``output`` alone is returned
        whenever ``self.return_bias`` is false.
        """
        fused = bool(envs.USE_FUSED_RMS_QUANT) and rms_weight is not None
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        if fused:
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      [i_q, scales])
        else:
            # Matrix multiply.
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        if fused:
            # NOTE(review): ``residual`` is returned by reference; presumably
            # the kernel updates it in place when ``update_input`` is set.
            return output, residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        eps: Epsilon passed to the fused ``lm_faster_rmsquant`` kernel on the
             USE_FUSED_RMS_QUANT path of ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.eps = eps
        # Must be set before the parent constructor runs; the base class may
        # consult it when sizing the per-partition outputs.
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        # Quantized methods keep the checkpoint layout; unquantized ones may
        # store the weight transposed when envs.VLLM_USE_NN is set (see
        # weight_loader / weight_loader_v2).
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def forward(
        self,
        input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the merged column-parallel projection.

        When ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is given,
        ``input_`` is first run through ``lm_faster_rmsquant`` (with
        ``rms_weight``, ``residual``, ``self.eps`` and int8 output). The
        resulting ``[quantized_input, scales]`` pair is handed to the quant
        method as pre-quantized input arguments.

        Args:
            input_: Layer input.
            rms_weight: Weight for the fused kernel. A non-None value enables
                the fused path when the env flag is set.
            residual: Residual tensor; required on the fused path.
            update_hd: Forwarded to the kernel as ``update_input``.

        Returns:
            Non-fused path: ``output`` or ``(output, output_bias)``.
            Fused path: ``output`` or ``(output, residual, output_bias)``.
            The single-tensor form is used when ``return_bias`` is False.
        """
        fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None
        if fused:
            assert residual is not None
            # NOTE(review): the kernel presumably updates ``residual`` in
            # place; it is returned to the caller unchanged by this method.
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        if fused:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      [i_q, scales])
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)

        # Gather only when the weights are actually sharded. With
        # disable_tp=True the layer is replicated (tp_size == 1), and the
        # all-gather helper does not know about that, so gathering would
        # concatenate duplicate outputs from every rank.
        if self.gather_output and self.tp_size > 1:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            # NOTE(review): on the fused path the residual is not returned
            # here; confirm callers using return_bias=False do not need it.
            return output
        if fused:
            return output, residual, output_bias
        return output, output_bias

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one shard (or the whole fused tensor) into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` identifying which
                sub-matrix ``loaded_weight`` is, or None when the checkpoint
                already stores the fused matrix (e.g. Phi-3's gate_up_proj).
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # starts[i] is the offset of sub-matrix i in the fused tensor.
            starts = list(itertools.accumulate([0, *self.output_sizes]))
            shard_offsets: list[tuple[int, int, int]] = [
                (i, starts[i], size)
                for i, size in enumerate(self.output_sizes)
            ]

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Offsets in the unquantized layout; invariant across shards.
                orig_offsets = {
                    str(i): (starts[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        # Unquantized weights are stored transposed under VLLM_USE_NN.
        transposed = envs.VLLM_USE_NN and not self.is_quantization

        if output_dim is not None:
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = (getattr(param, "is_sharded_weight", False)
                                 or use_bitsandbytes_4bit)

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            if transposed and param_data.dim() > 1:
                # The stored matrix is transposed, so the output axis is the
                # other one. 1-D params (e.g. a bias, presumably) have only
                # one axis, so they take the normal branch below instead of
                # indexing a dimension that does not exist.
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)
            else:
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)

            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if transposed:
            # Checkpoint layout -> stored (transposed) layout. ``.t()`` is a
            # no-op for 0-D and 1-D tensors.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        starts = itertools.accumulate([0, *self.output_sizes])
        for shard_id, (shard_offset, shard_size) in enumerate(
                zip(starts, self.output_sizes)):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor into a ``BasevLLMParameter``.

        Args:
            param: Destination parameter; it performs the actual copy via
                ``load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or None when the
                checkpoint stores the already-fused matrix.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are counted in (ceil-divided) blocks of block_n.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        load_kwargs = dict(loaded_weight=loaded_weight,
                           shard_id=loaded_shard_id,
                           shard_offset=shard_offset,
                           shard_size=shard_size,
                           tp_rank=self.tp_rank)
        if envs.VLLM_USE_NN:
            # The parameter needs to know whether its storage is transposed.
            load_kwargs["is_quantization"] = self.is_quantization
        param.load_merged_column_weight(**load_kwargs)
1044

1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across
                    tensor-parallel ranks.
    """

1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up the fused QKV projection and its per-rank head split.

        Query heads are divided across the tensor-parallel ranks. KV heads are
        divided too when there are more of them than ranks. Otherwise each
        rank holds one KV head, and that head is replicated
        ``num_kv_head_replicas`` times across ranks.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        # No explicit KV head count means plain multi-head attention.
        self.total_num_kv_heads = (total_num_heads if total_num_kv_heads is None
                                   else total_num_kv_heads)

        # Divide the weight matrix along the last dimension.
        world = 1 if disable_tp else get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, world)

        if world < self.total_num_kv_heads:
            self.num_kv_heads = divide(self.total_num_kv_heads, world)
            self.num_kv_head_replicas = 1
        else:
            # Fewer KV heads than ranks: one KV head per rank, replicated.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(world, self.total_num_kv_heads)

        q_width = self.num_heads * self.head_size * world
        kv_width = self.num_kv_heads * self.head_size * world
        # q_proj, k_proj, v_proj; set before super().__init__ because the
        # parent constructor may read it.
        self.output_sizes = [q_width, kv_width, kv_width]

        super().__init__(input_size=self.hidden_size,
                         output_size=(self.num_heads + 2 * self.num_kv_heads) *
                         world * self.head_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)
1126

1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return where shard ``loaded_shard_id`` starts in this rank's QKV slice.

        ``"q"``, ``"k"`` and ``"v"`` give each sub-matrix's start, and
        ``"total"`` gives the end of the slice. Unknown ids yield None.
        """
        q_span = self.num_heads * self.head_size
        kv_span = self.num_kv_heads * self.head_size
        return {
            "q": 0,
            "k": q_span,
            "v": q_span + kv_span,
            "total": q_span + 2 * kv_span,
        }.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the size of shard ``"q"``, ``"k"`` or ``"v"`` on this rank.

        Unknown ids yield None.
        """
        kv_span = self.num_kv_heads * self.head_size
        sizes = dict(q=self.num_heads * self.head_size, k=kv_span, v=kv_span)
        return sizes.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """Split an on-disk fused QKV tensor and load each part separately.

        Some checkpoints store q/k/v already concatenated and give no shard
        id. This slices the full (un-sharded) q, k and v sub-matrices out of
        ``loaded_weight`` along ``param.output_dim``. Each slice goes to
        ``weight_loader_v2`` with its shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        q_width = self.total_num_heads * self.head_size
        kv_width = self.total_num_kv_heads * self.head_size
        layout = (
            # (shard_id, shard_offset, shard_size)
            ("q", 0, q_width),
            ("k", q_width, kv_width),
            ("v", q_width + kv_width, kv_width),
        )

        # Packed quantized params store several values per element along the
        # packed dim, so offsets/sizes along the output dim must be rescaled.
        packed_on_output = (isinstance(
            param, (PackedColumnParameter, PackedvLLMParameter))
                            and param.packed_dim == param.output_dim)

        for shard_id, start, width in layout:
            if packed_on_output:
                width, start = param.adjust_shard_indexes_for_packing(
                    shard_size=width, shard_offset=start)
            piece = loaded_weight.narrow(param.output_dim, start, width)
            self.weight_loader_v2(param, piece, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into a ``BasevLLMParameter`` of the QKV layer.

        Args:
            param: Destination parameter; the copy is done by its
                ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or None when the
                checkpoint stores the already-fused QKV matrix.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match on purpose: subclasses take the split path.
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            block_size = self.weight_block_size
            assert block_size is not None
            block_n, _ = block_size[0], block_size[1]
            # Scales are stored per block, so express both in blocks.
            offset = (offset + block_n - 1) // block_n
            size = (size + block_n - 1) // block_n

        kwargs = dict(loaded_weight=loaded_weight,
                      num_heads=self.num_kv_head_replicas,
                      shard_id=loaded_shard_id,
                      shard_offset=offset,
                      shard_size=size,
                      tp_rank=self.tp_rank)
        if envs.VLLM_USE_NN:
            kwargs["is_quantization"] = self.is_quantization
        param.load_qkv_weight(**kwargs)
1230

1231
1232
1233
1234
    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
1235
1236
1237
1238
1239

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
1240
        if is_gguf_weight_type:
1241
            idx_map = {"q": 0, "k": 1, "v": 2}
1242
1243
1244
1245
1246
1247
1248
1249
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
1250
1251
            return

1252
1253
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
1254
1255
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
1256

1257
1258
1259
1260
1261
1262
1263
            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
1264

1265
1266
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
1267

1268
1269
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
1270

1271
        if loaded_shard_id is None:
1272
1273
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
1274
            if output_dim is None:
1275
                if needs_scalar_to_array:
1276
1277
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)
1278

1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
1290
1291
1292
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

1293
1294
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
1295
                # Special case for Quantized Weights.
1296
1297
1298
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
1299
1300
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
1301

1302
                    # Special case for Marlin.
1303
1304
1305
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

1323
1324
1325
1326
1327
1328
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]
1329
1330

        # If output dim is defined, use the default loading process.
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
1342
            # Special case for Quantized Weights.
1343
1344
1345
1346
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
1347
1348
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
1349

1350
                # Special case for Marlin.
1351
1352
1353
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

1354
1355
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
1356
1357
1358
1359
1360
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

1361
            if use_bitsandbytes_4bit:
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
1373
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
1374
                    param, orig_qkv_offsets, loaded_shard_id)
gaoqiong's avatar
gaoqiong committed
1375

1376
            if not envs.VLLM_USE_NN or len(param_data.shape)==1 or self.is_quantization:
1377
1378
1379
1380
1381
1382
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not(output_dim)), shard_offset,
                                               shard_size)
                
zhuwenwen's avatar
zhuwenwen committed
1383
            if loaded_shard_id == "q":
1384
                shard_id = self.tp_rank
1385
            else:
1386
                shard_id = self.tp_rank // self.num_kv_head_replicas
1387
            start_idx = shard_id * shard_size
1388

1389
            if not is_sharded_weight:
1390
1391
1392
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

1393
1394
1395
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
1396
                param_data, loaded_weight, loaded_shard_id)
1397
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
1398
1399
1400
1401
1402
1403
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
gaoqiong's avatar
gaoqiong committed
1404

1405
        if envs.VLLM_USE_NN and not self.is_quantization:
1406
1407
            loaded_weight = loaded_weight.t()
            
gaoqiong's avatar
gaoqiong committed
1408
1409
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1410
1411


1412
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
        # The weight loaders below only apply the VLLM_USE_NN (transposed)
        # layout when this is False, i.e. for UnquantizedLinearMethod.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of ``loaded_weight`` into ``param``.

        The checkpoint tensor is narrowed along ``param.input_dim`` at
        offset ``tp_rank * shard_size``. This is skipped when the param is
        already sharded (``is_sharded_weight`` or bitsandbytes 4-bit). GGUF
        ``UninitializedParameter``s are materialized first.

        When ``envs.VLLM_USE_NN`` is set and the layer is unquantized,
        ``param`` holds the transposed weight. In that case, the shard size
        of a 2-D param is read from its other dimension, and the loaded
        tensor is transposed before the copy.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None`, so that input_dim == 0 is also split per rank.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        use_nn_layout = envs.VLLM_USE_NN and not self.is_quantization

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # Only 2-D params are stored transposed in the NN layout.
            # A 1-D param keeps its axis (t() is a no-op on it), which
            # matches the 1-D guard used by QKVParallelLinear.
            if not use_nn_layout or param_data.dim() == 1:
                shard_size = param_data.shape[input_dim]
            else:
                shard_size = param_data.shape[int(not input_dim)]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if use_nn_layout:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``.

        Scalar checkpoint tensors are reshaped to shape ``(1,)``. The tensor
        is transposed first when VLLM_USE_NN is set and the layer is
        unquantized.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()
        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
        use_fused_silu_mul_quant: Optional[bool] = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        Args:
            input_: Full input, or this rank's slice when
                ``input_is_parallel`` is True.
            use_fused_silu_mul_quant: If true, ``lm_fuse_silu_mul_quant`` is
                run on the input. Its two outputs are passed to
                ``quant_method.apply`` as ``silu_quant_args``.

        Returns:
            The output, or ``(output, bias)`` when ``return_bias`` is set.
            The bias is returned only if ``skip_bias_add``; otherwise it is
            ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            # NOTE(review): lm_fuse_silu_mul_quant is not among the imports
            # visible at the top of this file; confirm it is in module scope.
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
            silu_quant_args = [xq, xs]
            output_parallel = self.quant_method.apply(
                self,
                input_parallel,
                bias_,
                silu_quant_args=silu_quant_args)
        else:
            output_parallel = self.quant_method.apply(self, input_parallel,
                                                      bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize per-partition sizes and parallel settings."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1610
1611


1612
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Q is projected from the decoder hidden states by a
    ``ColumnParallelLinear``. K and V are projected from the encoder hidden
    states by a ``QKVParallelLinear`` with zero query heads. Both sub-layers
    are held in a plain dict, so their parameters are not auto-registered
    on this module.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Apply the documented default here. Previously, None was treated
        # as 0 in output_size and was forwarded to the encoder
        # QKVParallelLinear, which is built with total_num_heads=0, so no
        # KV heads were created.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _synced_proj(
        self,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ) -> nn.Module:
        """Return ``self.proj[mode]`` after syncing placeholder attributes.

        For each parameter of this module, the same-named attribute on the
        sub-layer (if present) receives any attributes it is missing, via
        ``sync_weight_attrs``.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Decoder-side Q projection, with placeholder attributes synced."""
        return self._synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Encoder-side KV projection, with placeholder attributes synced."""
        return self._synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but not on ``tgt_param``.

        For bitsandbytes 4-bit params, the missing attributes are split by
        ``left_shift_bitsandbytes_4bit_shard`` into a Q part and a KV part.
        The part matching ``mode`` is then applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if (missing_attrs_dict and use_bitsandbytes_4bit):
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two params match when their types are identical and their
        ``__dict__``s are equal, ignoring the weight-loader entries.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param, 
        return the corresponding param in the proj layers.

        Asserts that exactly one parameter of ``layer`` matches.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """Return ``(q, k, v)``.

        ``k`` and ``v`` are ``None`` when ``encoder_hidden_states`` is
        ``None``; the branch comment says encoder KV is already cached.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Route a load to the sub-layer's v1 ``weight_loader``.

        Shard id ``"q"`` goes to the decoder Q projection, which is called
        without a shard id. Any other id goes to the encoder KV projection.
        """
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a load to the matching sub-layer.

        Uses ``weight_loader_v2`` when the quant method's class name is in
        ``WEIGHT_LOADER_V2_SUPPORTED``, otherwise ``weight_loader``.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize input size, per-rank Q/KV sizes and TP world size."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s