linear.py 75.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
32

zhuwenwen's avatar
zhuwenwen committed
33
import os
34
from vllm.model_executor.utils import gemm_bank_conf
35

36
37
38
39
40
41
logger = init_logger(__name__)

if envs.USE_FUSED_RMS_QUANT:
    # The fused RMSNorm + int8 quantization kernel used by the forward()
    # methods below comes from the optional `lmslim` package. Importing it is
    # best effort: any failure (not only ImportError, since native extension
    # loading can raise other errors) is reported through the module logger
    # rather than aborting module import.
    # NOTE(review): if this import fails, code paths that call
    # lm_faster_rmsquant will still fail later with NameError.
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        logger.error("Import fused rmsquant error: %s", e)

44
# Class names of quant methods whose layers are handed ``weight_loader_v2``:
# ColumnParallelLinear compares ``quant_method.__class__.__name__`` against
# this list when choosing the loader passed to ``create_weights``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
67

68

69
70
71
72
73
74
75
76
77
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    When ``param`` has a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are passed through untouched.

    Returns:
        ``(shard_size, shard_offset)`` after any adjustment.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


78
79
80
81
82
83
84
85
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size, if any.

    When ``param`` has a ``marlin_tile_size`` attribute, both values are
    multiplied by it; otherwise they are returned as given.

    Returns:
        ``(shard_size, shard_offset)`` after any adjustment.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


86
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Each entry of ``shard_offsets`` is unpacked as ``(offset, size)``; the
    first element of the ``"total"`` entry is the unquantized total. The
    selected shard's offset and size are rescaled proportionally to the
    number of rows actually stored in ``param`` (``param.data.shape[0]``),
    using integer floor division.

    Returns:
        ``(quantized_size, quantized_offset)`` for ``loaded_shard_id``.
    """
    unquantized_total, _unused = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]

    stored_rows = param.data.shape[0]
    scaled_offset = offset * stored_rows // unquantized_total
    scaled_size = size * stored_rows // unquantized_total
    return scaled_size, scaled_offset


101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-matrix scale array for one shard.

    Fused modules (QKV and MLP) hold one scale per "logical" matrix in an
    array of length N. ``loaded_weight`` is the scale of a single on-disk
    shard, and this returns the element of ``param`` it should be copied
    into together with the (possibly reshaped) scale.

    Args:
        param: Fused scale array, indexed by shard position.
        loaded_weight: Scale loaded from disk. AutoFP8 scales have no shape;
            compressed-tensors scales have a leading dimension of size 1,
            which is stripped.
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], weight)`` where ``weight`` is passed through
        ``.t()`` when ``envs.VLLM_USE_NN`` is set.

    Raises:
        ValueError: ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: ``shard_id`` is a string other than q/k/v.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # Drop the singleton leading dim that shaped (compressed-tensors)
    # scales carry; shapeless (AutoFP8) scales are used as-is.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    weight = loaded_weight.t() if envs.VLLM_USE_NN else loaded_weight
    return param[index], weight
125
126


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard metadata into its first shard and the rest.

    The left part keeps the first two offsets and quant state 0. The right
    part drops the first offset, re-bases the remaining offsets so they start
    at zero, and renumbers the remaining quant states from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    states = bnb_weight_attrs["bnb_quant_state"]

    left = dict(bnb_shard_offsets=offsets[:2],
                bnb_quant_state={0: states[0]})

    right_states = {}
    for new_idx, old_idx in enumerate(range(1, len(offsets) - 1)):
        right_states[new_idx] = states[old_idx]
    right = dict(bnb_shard_offsets=offsets[1:] - offsets[1],
                 bnb_quant_state=right_states)
    return left, right


164
class LinearMethodBase(QuantizeMethodBase):
    """Interface shared by every linear method, quantized or not.

    Implementations allocate a layer's parameters in ``create_weights`` and
    perform the layer's matmul in ``apply``.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the weights of a linear layer and attach them to it.

        The created weights become attributes of ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                parameters (the layers in this module pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer's weights, adding ``bias`` if given.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Layout/backend switches read by this class:

    * ``envs.VLLM_USE_NN``: ``create_weights`` allocates the weight as
      ``(input_size_per_partition, sum(output_partition_sizes))`` instead of
      ``(sum(output_partition_sizes), input_size_per_partition)``.
    * ``LLAMA_NN=1`` (environment): ``apply`` computes ``x @ weight`` with
      torch matmul ops instead of ``dispatch_unquantized_gemm``.
    * ``GEMM_PAD=1`` (environment): only consulted on the ``LLAMA_NN`` path,
      see ``apply``.
    """

    def __init__(self):
        super().__init__()
        # Read once at construction so apply() does not consult os.environ
        # on every forward call.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an uninitialized weight and register it on ``layer``.

        ``input_dim``/``output_dim`` are always set to 1/0, even when
        ``envs.VLLM_USE_NN`` stores the parameter transposed: the weight
        loaders in this module narrow the on-disk tensor along
        ``output_dim`` first and transpose it afterwards.
        """
        out_features = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            shape = (input_size_per_partition, out_features)
        else:
            shape = (out_features, input_size_per_partition)
        weight = Parameter(torch.empty(*shape, dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, hand the layer to ``dispatch_cpu_unquantized_gemm``
        (with ``remove_weight=True``); other platforms need no post-processing
        here."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the layer's linear transform of ``x`` (plus ``bias``)."""
        if self.use_llama_nn:
            # This path multiplies x @ weight, i.e. it treats the weight as
            # (in, out) -- presumably VLLM_USE_NN is also set; TODO confirm.
            #
            # NOTE(review): presumably GEMM_PAD means the weight was padded
            # with 32 extra columns elsewhere and gemm_bank_conf() decides
            # whether to drop them; the slice is kept as a strided view.
            # If gemm_bank_conf() can also hold for the trimmed width, a
            # later call would trim again -- confirm against that helper.
            # The flag is checked first so an unset GEMM_PAD can no longer
            # raise KeyError (the old code indexed os.environ directly).
            if (self.use_gemm_pad
                    and gemm_bank_conf(layer.weight.shape[1] - 32)):
                # Slicing a Parameter yields a plain Tensor, and nn.Module
                # rejects assigning a plain Tensor to a registered parameter
                # name (TypeError), so re-wrap it as a Parameter.
                layer.weight = Parameter(layer.weight[:, :-32],
                                         requires_grad=False)

            weight = layer.weight
            if bias is None:
                return torch.matmul(x, weight)
            if len(x.shape) == 2:
                return torch.addmm(bias, x, weight)
            return torch.matmul(x, weight) + bias

        # With VLLM_USE_NN the weight is stored (in, out); transpose it back
        # when the input's last dim matches that leading dim.
        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)
250

251

252
class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Stores the layer's sizes and dtype, resolves the quant method
    (``UnquantizedLinearMethod`` when no ``quant_config`` is given) and
    records the tensor-parallel rank/size.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization configure.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (rank 0 of a world of size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        # The quant config may inspect the attributes set above, so this
        # lookup must come after them.
        self.quant_method: Optional[QuantizeMethodBase] = (
            UnquantizedLinearMethod() if quant_config is None else
            quant_config.get_quant_method(self, prefix=prefix))

        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's TP rank/size onto each BasevLLMParameter it owns."""
        vllm_params = (p for p in self.parameters()
                       if isinstance(p, BasevLLMParameter))
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
307
308


309
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
            RMSNorm + quant path of ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on the weights of replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        Handles GGUF weight-type markers and lazy GGUF parameters, gives
        shapeless scales (e.g. AutoFP8) the shape ``(1,)``, and, for
        unquantized layers under ``envs.VLLM_USE_NN``, transposes the loaded
        tensor (a no-op for 1-D tensors such as the bias).
        """
        # Special case for GGUF
        if getattr(param, "is_gguf_weight_type", False):
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if (getattr(param, "is_gguf_weight", False)
                and isinstance(param, UninitializedParameter)):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def _apply_quant_method(self, input_: torch.Tensor,
                            input_quant_args: Optional[list] = None):
        """Run the quant method on ``input_`` honouring ``skip_bias_add``.

        ``input_quant_args`` is passed as a 4th argument to
        ``quant_method.apply`` only when it is not None.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set, else None.
        """
        bias = None if self.skip_bias_add else self.bias
        assert self.quant_method is not None
        if input_quant_args is None:
            output = self.quant_method.apply(self, input_, bias)
        else:
            output = self.quant_method.apply(self, input_, bias,
                                             input_quant_args)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the replicated linear layer.

        When ``envs.USE_FUSED_RMS_QUANT`` is set and either ``rms_weight`` or
        ``quant_args`` is given, the quant method receives pre-quantized
        activations as a 4th ``apply`` argument:

        * ``quant_args`` given: they are used as-is.
        * otherwise ``lm_faster_rmsquant`` computes ``[i_q, scales]`` from
          ``input_``, ``rms_weight``, ``self.eps``, ``residual`` and
          ``update_hd`` (forwarded as ``update_input``).

        NOTE(review): ``UnquantizedLinearMethod.apply`` takes no 4th
        argument, so the fused path presumably requires a quantized method.

        Returns:
            ``output`` if ``return_bias`` is False. Otherwise
            ``(output, output_bias)``, except on the ``lm_faster_rmsquant``
            path, which returns
            ``(output, residual, output_bias, input_quant_args)``.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and (
            rms_weight is not None or quant_args is not None)

        if not use_fused:
            output, output_bias = self._apply_quant_method(input_)
            if not self.return_bias:
                return output
            return output, output_bias

        if quant_args is not None:
            # The caller already quantized the activations.
            output, output_bias = self._apply_quant_method(input_, quant_args)
            if not self.return_bias:
                return output
            return output, output_bias

        i_q, scales = lm_faster_rmsquant(input=input_,
                                         rms_weight=rms_weight,
                                         epsilon=self.eps,
                                         quant_dtype=torch.int8,
                                         residual=residual,
                                         update_input=update_hd)
        input_quant_args = [i_q, scales]
        output, output_bias = self._apply_quant_method(input_,
                                                       input_quant_args)
        if not self.return_bias:
            return output
        return output, residual, output_bias, input_quant_args

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

450

451
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class: the
                       per-partition sizes come from ``self.output_sizes`` when
                       that attribute is set before this ``__init__`` runs.
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
             RMSNorm + quant path of ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # (Presumably subclasses set ``self.output_sizes`` before delegating
        # here; TODO confirm.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's column shard of ``loaded_weight`` into ``param``.

        The on-disk tensor is narrowed along ``param.output_dim`` (unless it
        is already sharded, e.g. bitsandbytes 4-bit), shapeless scales get
        the shape ``(1,)``, and for unquantized layers under
        ``envs.VLLM_USE_NN`` the result is transposed before copying.
        """
        output_dim = getattr(param, "output_dim", None)

        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = (getattr(param, "is_sharded_weight", False)
                             or getattr(param, "use_bitsandbytes_4bit",
                                        False))
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        if getattr(param, "is_gguf_weight_type", False):
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if (getattr(param, "is_gguf_weight", False)
                and isinstance(param, UninitializedParameter)):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Under VLLM_USE_NN an unquantized 2-D param is stored
            # transposed, while the on-disk tensor is narrowed along
            # output_dim before being transposed below; so the partition
            # width is read from the param's other axis.
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                shard_size = param_data.shape[output_dim]
            else:
                shard_size = param_data.shape[int(not output_dim)]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate column-parallel loading to the vLLM parameter itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply this rank's column shard, optionally all-gathering outputs.

        When ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is given,
        ``lm_faster_rmsquant`` first produces ``[i_q, scales]`` (using
        ``self.eps``, ``residual`` and ``update_hd`` as ``update_input``),
        which are passed to ``quant_method.apply`` as a 4th argument.

        Returns:
            ``output`` if ``return_bias`` is False; otherwise
            ``(output, residual, output_bias)`` on the fused path and
            ``(output, output_bias)`` elsewhere.
        """
        fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None

        input_quant_args = None
        if fused:
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            input_quant_args = [i_q, scales]

        bias = None if self.skip_bias_add else self.bias
        # Matrix multiply.
        assert self.quant_method is not None
        if input_quant_args is None:
            output_parallel = self.quant_method.apply(self, input_, bias)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        if fused:
            return output, residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon forwarded to ``lm_faster_rmsquant`` on the fused
             RMS-quant path of ``forward`` (only used when
             ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is given).
             Note that it comes BEFORE ``prefix`` positionally.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # ``eps`` precedes ``prefix`` positionally; the order is kept so that
        # existing positional callers keep working.
        self.eps = eps
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        # Each packed sub-output must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the merged column-parallel linear layer.

        When ``envs.USE_FUSED_RMS_QUANT`` is enabled and ``rms_weight`` is
        given, ``input_`` and ``residual`` are first passed through
        ``lm_faster_rmsquant`` (int8 output); the quantized tensor and its
        scales are handed to ``quant_method.apply`` as ``input_quant_args``.

        Returns:
            ``output`` when ``return_bias`` is False. Otherwise
            ``(output, output_bias)`` on the plain path, or
            ``(output, residual, output_bias)`` on the fused path.
        """
        if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
            assert residual is not None and rms_weight is not None
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            # The same residual tensor object is handed back to the caller;
            # presumably the kernel updates it in place — TODO confirm.
            new_residual = residual
            input_quant_args = [i_q, scales]

            bias = self.bias if not self.skip_bias_add else None
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)
            # Only gather when the output is actually sharded; with
            # disable_tp (tp_size == 1) gathering would duplicate the
            # replicated output across the real TP group.
            if self.gather_output and self.tp_size > 1:
                output = tensor_model_parallel_all_gather(output_parallel)
            else:
                output = output_parallel
            output_bias = self.bias if self.skip_bias_add else None
            if not self.return_bias:
                return output
            return output, new_residual, output_bias

        # Plain (non-fused) path.
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor (one sub-output, or all fused) into param.

        Args:
            param: destination parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: index into ``self.output_sizes`` of the packed
                sub-output being loaded, or None when the checkpoint already
                stores the fused tensor (it is then split and this method is
                called once per sub-output).
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        # Under VLLM_USE_NN, unquantized params are narrowed on the other axis
        # and the checkpoint tensor is transposed before the copy below.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # 1-D params (e.g. bias) have no second axis: narrowing on
            # ``int(not output_dim)`` would index dim 1 and raise, so they
            # always use output_dim (matches QKVParallelLinear).
            if (not envs.VLLM_USE_NN or param_data.dim() == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                # Unquantized 2-D params under VLLM_USE_NN: the output axis is
                # the other dimension (the checkpoint tensor is transposed
                # before the copy below).
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            # ``.t()`` leaves 0-D/1-D tensors unchanged.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load into a ``BasevLLMParameter`` via its ``load_merged_column_weight``.

        With ``loaded_shard_id`` None the checkpoint tensor is either loaded
        whole (per-tensor scales, plain row/base params) or split per
        sub-output by ``_load_fused_module_from_checkpoint``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            # Imported lazily, only when a block-quantized scale is loaded.
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are expressed in blocks of ``block_n`` rows.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # NOTE(review): this class never sets ``self.eps``, yet the inherited
        # ColumnParallelLinear.forward reads it on its fused RMS-quant path —
        # confirm the base class provides a value before using that path.
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank keeps one KV head
            # and every KV head is replicated over several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Per-rank offset of the "q"/"k"/"v" (or "total") slice; None if
        the id is unknown."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Per-rank size of the "q"/"k"/"v" slice; None if the id is
        unknown."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load into a ``BasevLLMParameter`` via its ``load_qkv_weight``.

        ``loaded_shard_id`` is "q", "k", "v", or None for a tensor that is
        already fused on disk.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are expressed in blocks of ``block_n`` rows.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor ("q"/"k"/"v" slice, or fused) into param.

        Args:
            param: destination parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or None when the checkpoint
                already stores the fused QKV tensor (it is then split and this
                method is called once per slice).
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        # Under VLLM_USE_NN, unquantized params are narrowed on the other axis
        # and the checkpoint tensor is transposed before the copy below.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank offsets/sizes as weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                # Unquantized 2-D params under VLLM_USE_NN: the output axis is
                # the other dimension (the checkpoint tensor is transposed
                # before the copy below).
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            # Q is partitioned per rank; K/V heads may be shared by
            # ``num_kv_head_replicas`` consecutive ranks.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            # ``.t()`` leaves 0-D/1-D tensors unchanged.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        # With disable_tp the layer behaves as a single, unsharded rank.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        # Without the all-reduce each rank holds only a partial sum, so a bias
        # added on every rank would be counted tp_size times.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is not sharded: every rank holds the full output_size.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        # Imported lazily here, as in the original layer construction.
        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The checkpoint tensor is narrowed along the param's ``input_dim``
        (when it has one and is not already pre-sharded) to the slice owned
        by ``self.tp_rank``. When ``envs.VLLM_USE_NN`` is set and the layer
        is unquantized, the loaded tensor is transposed before the copy, so
        the param is stored transposed relative to the checkpoint.

        Args:
            param: Destination parameter; may carry ``input_dim``,
                ``is_sharded_weight``, ``use_bitsandbytes_4bit`` and GGUF
                marker attributes.
            loaded_weight: Tensor read from the checkpoint.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness): input_dim == 0 is a valid
            # sharded dimension and must be divided too, matching the
            # narrowing check below.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        # Unquantized params are stored transposed under VLLM_USE_NN (see
        # the `.t()` applied right before the copy).
        use_nn_layout = envs.VLLM_USE_NN and not is_quantization

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            if use_nn_layout and param_data.dim() == 2:
                # Transposed 2-D storage: the checkpoint's input_dim lives on
                # the other axis of param_data (0 <-> 1).
                shard_size = param_data.shape[int(not input_dim)]
            else:
                # Non-transposed storage, or 1-D params where there is no
                # other axis to swap to.
                shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if use_nn_layout:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"RowParallelLinear weight shape mismatch: param "
            f"{tuple(param_data.shape)} vs loaded "
            f"{tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's own row-parallel logic.

        0-d tensors (e.g. per-tensor scales) are reshaped to shape ``(1,)``
        before being handed to ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        Args:
            input_: Either this rank's input slice (``input_is_parallel``)
                or the full input, which is split along its last dimension.

        Returns:
            The output (all-reduced across ranks when ``reduce_results`` and
            ``tp_size > 1``), or ``(output, output_bias)`` when
            ``return_bias`` is set; ``output_bias`` is ``self.bias`` only
            when ``skip_bias_add`` is set, else ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            if envs.VLLM_ENABLE_TBO:
                output = self.tbo_all_reduce(output_parallel)
            else:
                output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1541
1542


1543
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Q is projected from the decoder hidden states by a
    ``ColumnParallelLinear`` sub-layer, and K/V from the encoder hidden
    states by a ``QKVParallelLinear`` sub-layer with zero query heads. This
    module itself only owns empty placeholder parameters and routes
    checkpoint tensors to the matching sub-layer by shard id ("q" vs. the
    rest).

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Resolve the documented default here. The KV sub-layer below is
        # built with total_num_heads=0, so forwarding None would leave it
        # to derive the KV head count from zero query heads instead of
        # from this layer's total_num_heads.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj: dict[str, LinearBase] = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Empty placeholder; real bias tensors are routed to the
            # sub-layers by weight_loader_v1.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Q sub-layer, with placeholder param attributes synced onto it."""
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """KV sub-layer, with placeholder param attributes synced onto it."""
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but missing on
        ``tgt_param`` onto ``tgt_param``.

        For bitsandbytes 4-bit params the missing attributes are first split
        by ``left_shift_bitsandbytes_4bit_shard`` into a Q part and a KV part,
        and only the part matching ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two params match when they have the same concrete type and equal
        instance ``__dict__`` contents, ignoring weight-loader entries.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ("weight_loader", "_weight_loader")
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Asserts that exactly one parameter of ``layer`` matches.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """Return ``(q, k, v)``.

        ``k`` and ``v`` are ``None`` when ``encoder_hidden_states`` is
        ``None``; otherwise the KV projection output is split along the last
        dimension into chunks of ``self.kv_size``.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Route a tensor to the Q or KV sub-layer's ``weight_loader``."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        # The Q sub-layer is a ColumnParallelLinear and takes no shard id.
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a tensor to the Q or KV sub-layer, choosing the v2 loader
        when the quant method is listed in WEIGHT_LOADER_V2_SUPPORTED."""
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        # The Q sub-layer is a ColumnParallelLinear and takes no shard id.
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s