linear.py 78.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union, Tuple
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
from vllm import envs
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
17
18
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_all_reduce_crp_m32)
19
from vllm.logger import init_logger
20
21
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
22
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
23
# yapf: disable
24
from vllm.model_executor.parameter import (BasevLLMParameter,
25
                                           BlockQuantScaleParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
33

zhuwenwen's avatar
zhuwenwen committed
34
import os
35
from vllm.model_executor.utils import gemm_bank_conf
36
37
from lmslim.quantize.quant_ops import lm_faster_rmsquant
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
38
        
39
40
# Module-level logger, created from this module's ``__name__``.
logger = init_logger(__name__)

41
# Class names of the linear (quantization) methods that are loaded through
# ``weight_loader_v2``. The names are compared against
# ``quant_method.__class__.__name__``; methods not listed here use the
# plain ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
63

64

65
66
67
68
69
70
71
72
73
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for a BitBLAS-tiled parameter.

    When ``param`` carries a ``bitblas_tile_size`` attribute, both values
    are floor-divided by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset

    return shard_size // tile, shard_offset // tile


74
75
76
77
78
79
80
81
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for a Marlin-tiled parameter.

    When ``param`` carries a ``marlin_tile_size`` attribute, both values
    are multiplied by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile

    return shard_size, shard_offset


82
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The ``(offset, size)`` pair of ``loaded_shard_id`` is rescaled from the
    original index space, whose extent is ``shard_offsets["total"][0]``, to
    the length of the first dimension of ``param.data``.

    Args:
        param: Parameter whose ``data.shape[0]`` is the quantized extent.
        shard_offsets: Maps a shard id to ``(offset, size)``. The entry
            ``"total"`` holds the un-quantized extent as its first element.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note that the size comes
        first, the reverse of the ``(offset, size)`` order in
        ``shard_offsets``.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused scale array that a loaded scale belongs to.

    For fused modules (QKV and MLP) we have an array of length N that holds
    one scale for each "logical" matrix, so ``param`` is an array of length
    N. ``loaded_weight`` corresponds to one of the shards on disk. Here we
    slice ``param`` based on ``shard_id`` for loading.

    Args:
        param: The fused array of scales.
        loaded_weight: The scale read from disk, either without a shape or
            with a leading dimension of size 1.
        shard_id: An integer index, or one of ``"q"``, ``"k"``, ``"v"``
            (mapped to 0, 1, 2).

    Returns:
        ``(param[shard_id], loaded_weight)``; the loaded weight is
        transposed with ``.t()`` when ``envs.VLLM_USE_NN`` is set.

    Raises:
        ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: If ``shard_id`` is a string other than "q", "k" or "v".
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    if envs.VLLM_USE_NN:
        return param[shard_id], loaded_weight.t()
    return param[shard_id], loaded_weight
121
122


123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into first shard and the rest.

    The first returned dict describes shard 0 only. The second describes
    all remaining shards, with their offsets shifted left so they start at
    zero and their quant-state keys renumbered from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]

    # Offsets of the first shard, and of the rest rebased onto offsets[1].
    head_offsets = offsets[:2]
    tail_offsets = offsets[1:] - offsets[1]

    head_state = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    tail_state = {}
    for idx in range(1, len(offsets) - 1):
        tail_state[idx - 1] = bnb_weight_attrs["bnb_quant_state"][idx]

    head = {
        "bnb_shard_offsets": head_offsets,
        "bnb_quant_state": head_state,
    }
    tail = {
        "bnb_shard_offsets": tail_offsets,
        "bnb_quant_state": tail_state,
    }
    return head, tail


160
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights (for example a ``weight_loader`` callable).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def __init__(self):
        # Both switches are read from the environment once, at construction.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the un-quantized ``weight`` parameter and register it.

        With ``envs.VLLM_USE_NN`` the tensor is allocated as
        ``(input_size_per_partition, sum(output_partition_sizes))``;
        otherwise as ``(sum(output_partition_sizes),
        input_size_per_partition)``. In both cases the parameter is tagged
        with ``input_dim=1`` / ``output_dim=0`` and then with
        ``extra_weight_attrs``.
        """
        if envs.VLLM_USE_NN:
            weight = Parameter(torch.empty(input_size_per_partition,
                                           sum(output_partition_sizes),
                                           dtype=params_dtype),
                               requires_grad=False)
        else:
            weight = Parameter(torch.empty(sum(output_partition_sizes),
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Optionally repack the weight for the CPU SGL kernels.

        Only acts when running on CPU with ``envs.VLLM_CPU_SGL_KERNEL`` set.
        Sets ``layer.use_cpu_sgl`` to record whether the packed path is
        usable for this layer.
        """
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            N, K = layer.weight.size()
            dtype = layer.weight.dtype
            if (torch._C._cpu._is_amx_tile_supported()
                    and dtype == torch.bfloat16 and N % 32 == 0
                    and K % 32 == 0):
                packed_weight = torch.ops._C.convert_weight_packed(
                    layer.weight)
                assert packed_weight.size() == layer.weight.size()
                layer.weight.copy_(packed_weight)
                if layer.bias is not None:
                    layer.bias = Parameter(layer.bias.to(torch.float32),
                                           requires_grad=False)
                layer.use_cpu_sgl = True
            else:
                logger.warning(
                    "CPU SGL kernels require Intel AMX support,"
                    " bfloat16 weight, IC and OC are divisible by 32.")
                layer.use_cpu_sgl = False

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer weight and add ``bias`` if given.

        When ``LLAMA_NN=1`` the product is computed as ``x @ layer.weight``
        with plain torch ops; otherwise the GEMM is dispatched through
        ``dispatch_unquantized_gemm``.
        """
        if self.use_llama_nn:
            # Use the flag captured in __init__ rather than indexing
            # os.environ, which raised KeyError when GEMM_PAD was unset.
            if self.use_gemm_pad and gemm_bank_conf(
                    layer.weight.shape[1] - 32):
                # NOTE(review): this rebinds ``layer.weight`` on the forward
                # path; confirm it is compatible with how ``weight`` is
                # registered on the layer.
                layer.weight = layer.weight[:, :-32]

            if bias is None:
                return torch.matmul(x, layer.weight)
            if len(x.shape) == 2:
                return torch.addmm(bias, x, layer.weight)
            return torch.matmul(x, layer.weight) + bias

        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            # The weight's first dim matches the input features, so hand
            # the GEMM its transpose.
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)
261

262

263
264
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None, an
            ``UnquantizedLinearMethod`` is used.
        prefix: Passed to ``quant_config.get_quant_method`` to identify the
            layer.
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Subclasses must implement the forward pass."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon handed to the fused RMS-norm/quantization kernel in
            ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``; no sharding is applied.

        Raises:
            AssertionError: If the (possibly reshaped or transposed) loaded
                weight does not have the same size as ``param``.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # With VLLM_USE_NN the un-quantized parameter is allocated with its
        # two dims swapped, so the loaded weight is transposed to match.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor,
               tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
        """Apply the linear layer, optionally on a fused RMS-norm/quant input.

        Three paths exist:

        * Plain: ``envs.USE_FUSED_RMS_QUANT`` is off, or neither
          ``rms_weight`` nor ``quant_args`` is given.
        * Pre-quantized: ``quant_args`` is given and forwarded to the quant
          method as-is (``rms_weight`` is ignored on this path).
        * Fused: only ``rms_weight`` is given; ``lm_faster_rmsquant``
          produces the int8 input and its scales first.

        Args:
            input_: Input tensor.
            rms_weight: RMS-norm weight for the fused path.
            residual: Residual tensor handed to ``lm_faster_rmsquant``.
            quant_args: Quantized-input arguments for the quant method.
            update_hd: Passed to ``lm_faster_rmsquant`` as ``update_input``.

        Returns:
            ``output`` alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)`` on the plain and pre-quantized paths,
            or ``(output, residual, output_bias, [i_q, scales])`` on the
            fused path. ``output_bias`` is the bias when ``skip_bias_add``
            is set, else None.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                                  or quant_args is not None)

        # The bias is either added inside apply() or handed back to the
        # caller, depending on skip_bias_add.
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None

        if not use_fused:
            assert self.quant_method is not None
            output = self.quant_method.apply(self, input_, bias)
            if not self.return_bias:
                return output
            return output, output_bias

        if quant_args is not None:
            assert self.quant_method is not None
            output = self.quant_method.apply(self, input_, bias, quant_args)
            if not self.return_bias:
                return output
            return output, output_bias

        i_q, _scales = lm_faster_rmsquant(input=input_,
                                          rms_weight=rms_weight,
                                          epsilon=self.eps,
                                          quant_dtype=torch.int8,
                                          residual=residual,
                                          update_input=update_hd)
        # The same ``residual`` object that was passed in is returned.
        new_residual = residual
        input_quant_args = [i_q, _scales]

        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias, input_quant_args)
        if not self.return_bias:
            return output
        return output, new_residual, output_bias, input_quant_args

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

449

450
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This argument is not read
                       by the constructor; per-partition sizes are derived
                       from ``self.output_sizes`` when that attribute is
                       already set before this constructor runs.
        eps: Epsilon handed to the fused RMS-norm/quantization kernel in
             ``forward``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)
        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        The checkpoint tensor is narrowed along the parameter's
        ``output_dim`` to this tensor-parallel rank's shard, unless the
        parameter is marked as already sharded (or uses bitsandbytes
        4-bit), and is then copied into ``param.data``.

        Raises:
            AssertionError: If the prepared weight's shape differs from the
                parameter's shape.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                tp_size = get_tensor_model_parallel_world_size()
                assert final_shape[output_dim] % tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # With VLLM_USE_NN an un-quantized 2-D parameter is allocated
            # with its dims swapped, so the shard size is read from the
            # other dimension of the parameter.
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                shard_size = param_data.shape[output_dim]
            else:
                shard_size = param_data.shape[int(not (output_dim))]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load a weight by delegating to the parameter's own loader.

        Calls ``param.load_column_parallel_weight`` after giving a
        shapeless loaded weight a shape of ``(1,)``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor,
               tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the column-parallel linear layer.

        When ``envs.USE_FUSED_RMS_QUANT`` is set and ``rms_weight`` is
        given, ``lm_faster_rmsquant`` first produces an int8 input and its
        scales, which are handed to the quant method as a fourth argument.

        Args:
            input_: Input tensor.
            rms_weight: RMS-norm weight; enables the fused path.
            residual: Residual tensor handed to ``lm_faster_rmsquant``.
            update_hd: Passed to ``lm_faster_rmsquant`` as ``update_input``.

        Returns:
            ``output`` alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)`` on the plain path, or
            ``(output, residual, output_bias)`` on the fused path.
            ``output_bias`` is the bias when ``skip_bias_add`` is set, else
            None.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None

        # The bias is either added inside apply() or handed back to the
        # caller, depending on skip_bias_add.
        bias = self.bias if not self.skip_bias_add else None
        output_bias = self.bias if self.skip_bias_add else None

        # Matrix multiply.
        if use_fused:
            i_q, _scales = lm_faster_rmsquant(input=input_,
                                              rms_weight=rms_weight,
                                              epsilon=self.eps,
                                              quant_dtype=torch.int8,
                                              residual=residual,
                                              update_input=update_hd)
            input_quant_args = [i_q, _scales]
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)
        else:
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        if use_fused:
            # The same ``residual`` object that was passed in is returned.
            return output, residual, output_bias
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        eps: Epsilon forwarded to `lm_faster_rmsquant` in the fused
             forward path.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True,
        xqxs: Optional[tuple] = None
    ) -> Union[torch.Tensor,
               tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
               ]:
        """Apply the merged column-parallel projection.

        Three paths are selected from environment flags and the optional
        arguments:

        * ``envs.USE_FUSED_RMS_QUANT`` and ``rms_weight`` given: the input is
          first passed through ``lm_faster_rmsquant`` and the resulting
          ``[i_q, scales]`` pair is handed to ``quant_method.apply``.
        * ``envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT`` and ``xqxs`` given:
          ``xqxs`` is handed to ``quant_method.apply`` as
          ``input_quant_args``.
        * otherwise: plain ``quant_method.apply(self, input_, bias)``.

        Args:
            input_: Activation passed to ``quant_method.apply``.
            rms_weight: Weight forwarded to ``lm_faster_rmsquant``; enables
                the first path when not None.
            residual: Residual forwarded to ``lm_faster_rmsquant``; required
                on the first path.
            update_hd: Forwarded to ``lm_faster_rmsquant`` as
                ``update_input``.
            xqxs: Pre-computed quantization arguments for the second path.

        Returns:
            ``output`` alone when ``self.return_bias`` is false. Otherwise
            ``(output, residual, i_q, scales, output_bias)`` on the first
            path and ``(output, output_bias)`` on the other two.
            ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set,
            else None.
        """
        output_bias = self.bias if self.skip_bias_add else None

        if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
            assert residual is not None
            i_q, _scales = lm_faster_rmsquant(input=input_,
                                              rms_weight=rms_weight,
                                              epsilon=self.eps,
                                              quant_dtype=torch.int8,
                                              residual=residual,
                                              update_input=update_hd)
            output = self._merged_apply_and_gather(input_, [i_q, _scales])
            if not self.return_bias:
                return output
            # NOTE(review): `residual` is returned as the new residual without
            # reassignment; presumably the kernel updates it in place.
            return output, residual, i_q, _scales, output_bias

        if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
            output = self._merged_apply_and_gather(input_,
                                                   input_quant_args=xqxs)
        else:
            output = self._merged_apply_and_gather(input_)

        if not self.return_bias:
            return output
        return output, output_bias

    def _merged_apply_and_gather(self, input_, *quant_args, **quant_kwargs):
        """Run ``quant_method.apply`` and all-gather when ``gather_output``."""
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias,
                                                  *quant_args, **quant_kwargs)
        if self.gather_output:
            # All-gather across the partitions.
            return tensor_model_parallel_all_gather(output_parallel)
        return output_parallel
    
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Build the fused layer whose output width is ``sum(output_sizes)``.

        ``eps`` and ``output_sizes`` are stored on the instance before the
        parent constructor runs. Every entry of ``output_sizes`` must be
        divisible by the tensor-parallel world size.
        """
        self.eps = eps
        self.output_sizes = output_sizes

        world_size = get_tensor_model_parallel_world_size()
        # Each merged sub-matrix has to split evenly across the TP ranks.
        assert not any(size % world_size != 0 for size in output_sizes)

        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
        )
    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy a checkpoint tensor (or one shard of it) into ``param``.

        Args:
            param: Destination parameter. Optional attributes read via
                ``getattr`` (``output_dim``, ``packed_dim``, ``is_metadata``,
                ``needs_scalar_to_array``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, GGUF flags, ...) select the loading
                path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded. ``None`` means the checkpoint tensor
                is already fused; it is then split per entry of
                ``self.output_sizes`` and this method is called again for
                each piece.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        # With VLLM_USE_NN an unquantized checkpoint tensor is transposed
        # before being copied into the parameter (see the end of this method).
        use_nn_layout = envs.VLLM_USE_NN and not is_quantization

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id for each piece.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # In the NN layout the destination is narrowed along the other
            # axis. 1-D parameters (e.g. a bias) have no second axis, and
            # `.t()` below leaves them unchanged, so they must always be
            # narrowed along `output_dim`; flipping the axis for them made
            # `narrow` fail with a dimension-out-of-range error.
            if use_nn_layout and param_data.dim() >= 2:
                param_narrow_dim = int(not (output_dim))
            else:
                param_narrow_dim = output_dim
            param_data = param_data.narrow(param_narrow_dim, shard_offset,
                                           shard_size)

            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if use_nn_layout:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Load an MLP weight that is already stored fused in the checkpoint.

        No shard id is available in that case, so the tensor is cut along
        ``param.output_dim`` into one slice per entry of
        ``self.output_sizes`` and each slice is passed to
        ``weight_loader_v2`` with its index as the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        running_offset = 0
        for shard_id, full_size in enumerate(self.output_sizes):
            offset, size = running_offset, full_size
            running_offset += full_size

            # Packed (quantized) parameters store several values per element
            # along the output dimension, so indexes must be rescaled.
            packed_along_output = (
                isinstance(param,
                           (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim)
            if packed_along_output:
                size, offset = param.adjust_shard_indexes_for_packing(
                    shard_size=size, shard_offset=offset)

            piece = loaded_weight.narrow(param.output_dim, offset, size)
            self.weight_loader_v2(param, piece, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` through the parameter's own loader.

        Without a shard id the tensor is either handed to
        ``load_merged_column_weight`` directly (per-tensor scales and plain
        Row/Base parameters) or split by
        ``_load_fused_module_from_checkpoint``. With a shard id, the offset
        and size of that shard on this rank are computed and passed along;
        for block-quantized scales they are first converted to block units.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses fall through to the split path.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        # Position and width of this shard before tensor-parallel splitting.
        full_offset = sum(self.output_sizes[:loaded_shard_id])
        full_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Round up to whole blocks along the output dimension.
            full_offset = (full_offset + block_n - 1) // block_n
            full_size = (full_size + block_n - 1) // block_n

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=full_offset // tp_size,
                                        shard_size=full_size // tp_size)

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        """Derive per-rank head counts and build the fused QKV projection.

        Query heads are divided across the tensor-parallel ranks. When there
        are at least as many ranks as KV heads, each rank holds one KV head
        and ``num_kv_head_replicas`` records how many ranks share it;
        otherwise the KV heads are divided across ranks as well.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                   is None else total_num_kv_heads)

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)

        q_width = self.num_heads * self.head_size * tp_size
        kv_width = self.num_kv_heads * self.head_size * tp_size
        self.output_sizes = [
            q_width,  # q_proj
            kv_width,  # k_proj
            kv_width,  # v_proj
        ]

        fused_width = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(
            input_size=self.hidden_size,
            output_size=fused_width,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
        )
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Load a QKV weight that is already stored fused in the checkpoint.

        No shard id is available in that case, so the tensor is cut along
        ``param.output_dim`` into its q, k and v slices (sized from the
        total head counts) and each slice is passed to ``weight_loader_v2``
        with the matching shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        q_width = self.total_num_heads * self.head_size
        kv_width = self.total_num_kv_heads * self.head_size
        # (shard_id, shard_offset, shard_size)
        layout = (
            ("q", 0, q_width),
            ("k", q_width, kv_width),
            ("v", q_width + kv_width, kv_width),
        )

        for shard_id, offset, size in layout:
            # Packed (quantized) parameters store several values per element
            # along the output dimension, so indexes must be rescaled.
            packed_along_output = (
                isinstance(param,
                           (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim)
            if packed_along_output:
                size, offset = param.adjust_shard_indexes_for_packing(
                    shard_size=size, shard_offset=offset)

            piece = loaded_weight.narrow(param.output_dim, offset, size)
            self.weight_loader_v2(param, piece, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` through the parameter's ``load_qkv_weight``.

        Without a shard id the tensor is either handed over directly
        (per-tensor scales and plain Row/Base parameters) or split by
        ``_load_fused_module_from_checkpoint``. With a shard id ("q", "k"
        or "v"), the per-rank offset and size are looked up and, for
        block-quantized scales, rounded up to whole blocks.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses fall through to the split path.
                param.load_qkv_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Convert element counts to block counts, rounding up.
            offset = (offset + block_n - 1) // block_n
            size = (size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=offset,
                              shard_size=size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
1172
1173
1174
1175
1176

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
1177
        if is_gguf_weight_type:
1178
            idx_map = {"q": 0, "k": 1, "v": 2}
1179
1180
1181
1182
1183
1184
1185
1186
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
1187
1188
            return

1189
1190
1191
        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()
1192

1193
1194
1195
1196
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

1197
1198
1199
1200
1201
1202
1203
            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
1204

1205
1206
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
1207
        # Special case for AQLM codebooks.
James Fleming's avatar
James Fleming committed
1208
        is_metadata = getattr(param, "is_metadata", False)
1209

1210
1211
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
1212
        is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
1213

1214
        if loaded_shard_id is None:
1215
1216
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
1217
            if output_dim is None:
1218
                if needs_scalar_to_array:
1219
1220
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)
1221

1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
1233
1234
1235
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

1236
1237
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
1238
                # Special case for Quantized Weights.
1239
1240
1241
1242
1243
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
1244

1245
                    # Special case for Marlin.
1246
1247
1248
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

1266
1267
1268
1269
1270
1271
1272
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
1273
1274

        # If output dim is defined, use the default loading process.
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
1286
            # Special case for Quantized Weights.
1287
1288
1289
1290
1291
1292
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
1293

1294
                # Special case for Marlin.
1295
1296
1297
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

1298
1299
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
1300
1301
1302
1303
1304
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

1305
            if use_bitsandbytes_4bit:
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
1317
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
1318
                    param, orig_qkv_offsets, loaded_shard_id)
gaoqiong's avatar
gaoqiong committed
1319

1320
1321
1322
1323
1324
1325
1326
            if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not(output_dim)), shard_offset,
                                               shard_size)
                
zhuwenwen's avatar
zhuwenwen committed
1327
            if loaded_shard_id == "q":
1328
1329
1330
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
1331
            start_idx = shard_id * shard_size
1332

1333
            if not is_sharded_weight:
1334
1335
1336
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

1337
        # Special case for AQLM codebooks.
James Fleming's avatar
James Fleming committed
1338
1339
1340
1341
1342
1343
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
1344
1345
1346
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
1347
                param_data, loaded_weight, loaded_shard_id)
1348
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
1349
1350
1351
1352
1353
1354
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
gaoqiong's avatar
gaoqiong committed
1355

1356
1357
1358
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()
            
gaoqiong's avatar
gaoqiong committed
1359
1360
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1361
1362


1363
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the first dimension.
        # These attributes are set before super().__init__() on purpose; the
        # original ordering is preserved.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # parameter-driven loader; everything else uses the legacy loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        # Imported at construction time rather than at module import time.
        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` (legacy loader).

        When ``param`` declares an ``input_dim`` and is not flagged as already
        sharded, ``loaded_weight`` is narrowed along ``input_dim`` to the
        slice belonging to this tensor-parallel rank before the copy.

        Args:
            param: Destination parameter; optional attributes ``input_dim``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``is_gguf_weight`` and ``is_gguf_weight_type`` steer loading.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the final tensor shape does not match the
                parameter shape.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # NOTE(review): truthiness test means input_dim == 0 is not
            # divided by tp_size here -- confirm that is intended.
            if input_dim:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # With VLLM_USE_NN on an unquantized layer the shard size is read
            # from the other axis of the parameter (presumably because the
            # parameter is stored transposed; the checkpoint tensor is
            # transposed below before the copy).
            if not envs.VLLM_USE_NN or is_quantization:
                shard_size = param_data.shape[input_dim]
            else:
                shard_size = param_data.shape[int(not (input_dim))]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` by delegating to the parameter itself.

        A zero-dimensional tensor is first reshaped to shape ``(1,)``; the
        actual loading is done by ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        use_fused_silu_mul_quant: Optional[bool] = False,
        pa_rms_weight: Optional[torch.Tensor] = None,
        pa_residual: Optional[torch.Tensor] = None,
        pa_rms_eps: Optional[float] = 1e-6,
        pa_quant_dtype: Optional[torch.dtype] = torch.int8,
        update_input: Optional[bool] = True
    ) -> Union[torch.Tensor,
               tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]]
               ]:
        """Apply the row-parallel GEMM and (optionally) reduce across ranks.

        Two reduction paths exist:

        * Fused path -- taken when ``envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT``
          is set and both ``pa_rms_weight`` and ``pa_residual`` are given. The
          partial output is handed to
          ``tensor_model_parallel_all_reduce_crp_m32`` together with the
          ``pa_*`` / ``update_input`` arguments, and its four results are
          returned.
        * Default path -- plain all-reduce (or ``tbo_all_reduce`` when
          ``envs.VLLM_ENABLE_TBO`` is set).

        Args:
            input_: Input activations; split along the last dim here unless
                ``input_is_parallel`` is true.
            use_fused_silu_mul_quant: If true, ``lm_fuse_silu_mul_quant`` is
                run on the input and its two results are passed to the quant
                method as ``silu_quant_args``.
            pa_rms_weight, pa_residual, pa_rms_eps, pa_quant_dtype,
            update_input: Forwarded unchanged to
                ``tensor_model_parallel_all_reduce_crp_m32`` on the fused path.

        Returns:
            ``output`` alone when ``return_bias`` is false; otherwise
            ``(output, output_bias)`` on the default path or
            ``(output, resi, xq, xs, output_bias)`` on the fused path, where
            ``output_bias`` is the bias if ``skip_bias_add`` else ``None``.

        Raises:
            RuntimeError: On the fused path when no reduction takes place
                (``reduce_results`` false or ``tp_size == 1``) and
                ``return_bias`` is true, because the fused results do not
                exist in that case.
        """
        use_fused_all_reduce = (envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
                                and pa_rms_weight is not None
                                and pa_residual is not None)

        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            silu_xq, silu_xs = lm_fuse_silu_mul_quant(input_parallel)
            output_parallel = self.quant_method.apply(
                self,
                input_parallel,
                bias=bias_,
                silu_quant_args=[silu_xq, silu_xs])
        else:
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)

        needs_reduce = self.reduce_results and self.tp_size > 1
        output_bias = self.bias if self.skip_bias_add else None

        if use_fused_all_reduce:
            if not needs_reduce:
                if not self.return_bias:
                    return output_parallel
                # Previously this fell through to a return statement that
                # referenced never-assigned locals (UnboundLocalError).
                raise RuntimeError(
                    "RowParallelLinear: the fused all-reduce RMS-quant path "
                    "requires reduce_results=True and tp_size > 1; no fused "
                    "residual/quantized outputs are available otherwise.")

            if envs.VLLM_ENABLE_TBO:
                # NOTE(review): the result of this call is discarded and the
                # fused all-reduce below still runs on output_parallel. Kept
                # to preserve existing behavior -- confirm whether both
                # reductions are intended.
                self.tbo_all_reduce(output_parallel)

            output, resi, xq, xs = tensor_model_parallel_all_reduce_crp_m32(
                output_parallel,
                pa_rms_weight=pa_rms_weight,
                pa_residual=pa_residual,
                pa_rms_eps=pa_rms_eps,
                pa_quant_dtype=pa_quant_dtype,
                update_input=update_input)

            if not self.return_bias:
                return output
            return output, resi, xq, xs, output_bias

        # RQ and default forward
        if needs_reduce:
            if envs.VLLM_ENABLE_TBO:
                output = self.tbo_all_reduce(output_parallel)
            else:
                output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary string shown in the module repr."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1618
1619


1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
class QKVCrossParallelLinear(LinearBase):
    """Cross-attention QKV projection built from two inner linear layers.

    The query projection is a ``ColumnParallelLinear`` applied to decoder
    hidden states; the key/value projection is a ``QKVParallelLinear``
    (constructed with zero query heads) applied to encoder hidden states.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # input_size and output_size are not used, just for alignment
        kv_heads_for_sizing = total_num_kv_heads or 0
        super().__init__(
            input_size=hidden_size,
            output_size=(total_num_heads + kv_heads_for_sizing) * head_size,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix)

        self.quant_config = quant_config

        # Zero-sized placeholder weights so this module loads as one unit.
        assert self.quant_method is not None
        empty = 0
        self.quant_method.create_weights(self,
                                         empty, [empty],
                                         empty,
                                         empty,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # The inner layers live in a plain dict so that their parameters are
        # not auto-registered on this module: drop-in replacement for a
        # `QKVParallelLinear` module.
        shared_kwargs = {
            "bias": bias,
            "quant_config": quant_config,
            "skip_bias_add": skip_bias_add,
            "params_dtype": params_dtype,
        }
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            prefix=f"{prefix}.q_proj_decoder",
            **shared_kwargs)
        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            prefix=f"{prefix}.kv_proj_encoder",
            **shared_kwargs)

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        # These reads go through the properties below and must stay ahead of
        # the bias registration.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if not bias:
            self.bias = None
        else:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })

    def process_weights_after_loading(self):
        """Run the quant method's post-load hook on each inner layer."""
        for inner_layer in self.proj.values():
            if self.quant_method is None:
                continue
            self.quant_method.process_weights_after_loading(inner_layer)

    def _synced_proj(self, mode: Literal["q_proj_decoder",
                                         "kv_proj_encoder"]):
        """Return ``self.proj[mode]`` after syncing placeholder attributes.

        For every parameter registered on this module, if the inner layer has
        an attribute of the same name, attributes missing on the inner
        parameter are copied over via ``sync_weight_attrs``.
        """
        inner_layer = self.proj[mode]
        for param_name, placeholder in self.named_parameters():
            inner_param = getattr(inner_layer, param_name, None)
            if inner_param is None:
                continue
            self.sync_weight_attrs(placeholder, inner_param, mode=mode)
        return inner_layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Decoder-side query projection, with attributes synced."""
        return self._synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Encoder-side key/value projection, with attributes synced."""
        return self._synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but absent on ``tgt_param``.

        When the source is flagged ``use_bitsandbytes_4bit`` and there is
        something to copy, the attributes are first passed through
        ``left_shift_bitsandbytes_4bit_shard`` and the half matching ``mode``
        is applied.
        """
        src_keys = set(vars(src_param).keys())
        tgt_keys = set(vars(tgt_param).keys())
        missing_attrs_dict = {
            key: getattr(src_param, key)
            for key in (src_keys - tgt_keys)
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if not (missing_attrs_dict and use_bitsandbytes_4bit):
            set_weight_attrs(tgt_param, missing_attrs_dict)
            return

        q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
            missing_attrs_dict)
        if mode == "q_proj_decoder":
            set_weight_attrs(tgt_param, q_proj_attrs)
        elif mode == "kv_proj_encoder":
            set_weight_attrs(tgt_param, kv_proj_attrs)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        ignored = ["weight_loader", "_weight_loader"]

        def comparable_attrs(p):
            return {
                key: value
                for key, value in p.__dict__.items() if key not in ignored
            }

        types_match = type(src_param) is type(map_param)
        # Evaluated unconditionally, exactly as before (no short-circuit).
        attrs_match = comparable_attrs(src_param) == comparable_attrs(
            map_param)
        return types_match and attrs_match

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.
        """
        matches = [
            candidate for _, candidate in layer.named_parameters()
            if self._is_same_param(param, candidate)
        ]
        # Exactly one inner parameter must correspond to the placeholder.
        assert len(matches) == 1
        return matches[0]

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project decoder states to q and, if given, encoder states to k/v.

        Returns ``(q, k, v)``; ``k`` and ``v`` are ``None`` when
        ``encoder_hidden_states`` is ``None``.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            return q, None, None

        # Prefill phase, encoder KV cached here.
        kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
        # Split kv in half
        k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a checkpoint tensor to the matching inner layer's loader.

        Shard id ``"q"`` goes to the query projection (loader called without
        a shard id); anything else goes to the key/value projection with the
        shard id passed through.
        """
        is_query = loaded_shard_id == "q"
        layer = self.q_proj_decoder if is_query else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = () if is_query else (loaded_shard_id, )

        use_v2 = (self.quant_method.__class__.__name__
                  in WEIGHT_LOADER_V2_SUPPORTED)
        loader = layer.weight_loader_v2 if use_v2 else layer.weight_loader
        loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Return the layer summary string shown in the module repr."""
        fields = [
            f"in_features={self.input_size}",
            f"q_size={self.q_size}",
            f"kv_size={self.kv_size}",
            f"bias={self.bias is not None}",
            f"tp_size={get_tensor_model_parallel_world_size()}",
            "gather_output=False",
        ]
        return ", ".join(fields)