"...src/kernels/kCalculateAmoebaCudaPmeFixedEField.cu" did not exist on "73b55e336e432b2a67a01052d2ab2bbf155d6272"
linear.py 66.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
32
33
34

# Module-level logger (used e.g. for the missing-`output_dim` warning in
# MergedColumnParallelLinear.weight_loader).
logger = init_logger(__name__)

35
# Class names of quantization linear methods whose parameters are loaded with
# ``weight_loader_v2`` (BasevLLMParameter-based loading) rather than the
# legacy ``weight_loader``. Layers compare ``quant_method.__class__.__name__``
# against this list when choosing the loader passed to ``create_weights``.
# Order is preserved for anyone who inspects the list; membership is what
# the loader selection actually uses.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
57

58

59
60
61
62
63
64
65
66
67
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    If ``param`` has a non-None ``bitblas_tile_size`` attribute, both values
    are floor-divided by it. Otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the adjustment.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


68
69
70
71
72
73
74
75
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size.

    When ``param`` has a non-None ``marlin_tile_size`` attribute, both values
    are multiplied by it. Otherwise they pass through unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the adjustment.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


76
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Map an unquantized shard's (offset, size) onto the quantized param.

    Args:
        param: The quantized parameter; its ``data.shape[0]`` is used as the
            quantized total length.
        shard_offsets: Maps each shard id to ``(offset, size)`` in unquantized
            units. The special key ``"total"`` holds ``(total_size, _)``.
        loaded_shard_id: Key into ``shard_offsets`` for the shard to adjust.

    Returns:
        ``(size, offset)`` rescaled proportionally to the quantized length
        (floor division).
    """
    unquantized_total = shard_offsets["total"][0]
    start, length = shard_offsets[loaded_shard_id]
    packed_total = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Proportional rescale from unquantized to quantized units.
        return value * packed_total // unquantized_total

    return _rescale(length), _rescale(start)


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for loading.

    Fused modules (QKV and MLP) keep an array of length N holding one scale
    per "logical" matrix. ``loaded_weight`` is the scale of one on-disk shard;
    this returns the element of ``param`` that it should be copied into,
    together with the scale reduced to a 0-d value.

    Args:
        param: The fused scale array, indexed by shard position.
        loaded_weight: The scale for a single shard. It is either 0-d
            (e.g. AutoFP8) or has a leading dimension of size 1
            (e.g. compressed-tensors).
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an
            integer position.

    Returns:
        ``(param[shard_id], loaded_weight)`` with the weight squeezed to 0-d.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or not an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Previously an unknown name surfaced as a bare KeyError; report it
        # the same way as other invalid ids.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """Split BitsAndBytes 4-bit shard attributes into (first, rest).

    The "left" part keeps only the first shard. The "right" part keeps the
    remaining shards, with offsets shifted so that they start at 0 and
    quant-state keys renumbered from 0.

    Example, given::

        {
            'bnb_shard_offsets': array([0, 4, 8, 16]),
            'bnb_quant_state': {0: ..., 1: ..., 2: ...},
        }

    the result is::

        left  = {'bnb_shard_offsets': array([0, 4]),
                 'bnb_quant_state': {0: ...}}
        right = {'bnb_shard_offsets': array([0, 4, 12]),
                 'bnb_quant_state': {0: ..., 1: ...}}

    ``bnb_shard_offsets`` must support slicing and scalar subtraction
    (e.g. a numpy array).
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    left_offsets = offsets[:2]
    right_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    left_states = {0: states[0]}
    right_states = {
        idx: states[idx + 1]
        for idx in range(len(offsets) - 2)
    }

    left = {"bnb_shard_offsets": left_offsets,
            "bnb_quant_state": left_states}
    right = {"bnb_shard_offsets": right_offsets,
             "bnb_quant_state": right_states}
    return left, right


151
class LinearMethodBase(QuantizeMethodBase):
    """Abstract interface for (optionally quantized) linear layer methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the weights of a linear layer and attach them to ``layer``.

        Args:
            layer: The layer that owns this linear method.
            input_size_per_partition: Size of the weight's input dimension on
                the current rank.
            output_partition_sizes: Size of the output dimension of each
                logical weight on the current rank. For a QKV layer, this
                holds the widths of Wq, Wk and Wv on this rank.
            input_size: Size of the weight's input dimension across all
                ranks.
            output_size: Size of the weight's output dimension across all
                ranks.
            params_dtype: Data type of the parameters.
            **extra_weight_attrs: Additional attributes to attach to the
                created weights (callers pass e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer's weights on ``x`` (adding ``bias`` if given).

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Plain linear method: a dense, non-quantized weight matrix."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a dense ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. It is
        tagged with ``output_dim=0`` / ``input_dim=1`` and then with
        ``extra_weight_attrs``.
        """
        rows = sum(output_partition_sizes)
        storage = torch.empty(rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        # Dim 0 is the output dim, dim 1 is the input dim.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if not current_platform.is_cpu():
            return
        from vllm.model_executor.layers.utils import (
            dispatch_cpu_unquantized_gemm)
        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the layer output using the dispatched unquantized GEMM."""
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
213
214


215
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization config; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (rank 0 of a world of size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Record constructor arguments on the module.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        # The quant method is chosen after the attributes above are set,
        # since get_quant_method receives this (partially built) layer.
        self.quant_method: Optional[QuantizeMethodBase] = (
            UnquantizedLinearMethod() if quant_config is None else
            quant_config.get_quant_method(self, prefix=prefix))
        self.return_bias = return_bias

        self.disable_tp = disable_tp
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's TP rank/size onto each BasevLLMParameter."""
        vllm_params = (p for p in self.parameters()
                       if isinstance(p, BasevLLMParameter))
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
270
271


272
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The weight is not sharded. ``create_weights`` is called with the full
    ``input_size`` and ``[output_size]``, and the loader copies checkpoint
    tensors verbatim.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on sharding for replicated linear layers;
                    it is only forwarded to LinearBase.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        Raises:
            AssertionError: If the (possibly reshaped) loaded tensor's size
                differs from ``param``'s size.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The two f-string halves previously ran together with no space.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        # With skip_bias_add the bias is handed back for the caller to fuse.
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

369

370
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class
                       itself: per-shard partition sizes come from
                       ``self.output_sizes`` when a subclass sets that
                       attribute before calling this constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed before LinearBase.__init__ runs.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # self.params_dtype is the resolved dtype (default dtype when
            # params_dtype is None), matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        Unless the parameter is already sharded (``is_sharded_weight``) or
        uses bitsandbytes 4-bit loading, the slice
        ``[tp_rank * shard_size, (tp_rank + 1) * shard_size)`` along
        ``param.output_dim`` is copied. Here ``shard_size`` is ``param``'s own
        size along that dimension.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate column-parallel loading to the vLLM parameter itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
558
        quant_config: Quantization configure.
559
560
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
561
        return_bias: If true, return bias together with outputs in forward pass.
562
563
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
564
565
    """

566
567
568
569
570
571
572
573
574
575
576
577
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a fused layer whose output concatenates ``output_sizes``.

        ``self.output_sizes`` and the TP rank/size are set before the parent
        constructor runs. ColumnParallelLinear.__init__ checks
        ``hasattr(self, "output_sizes")`` to derive per-shard partition sizes.
        """
        # Must exist before ColumnParallelLinear.__init__ is called.
        self.output_sizes = output_sizes

        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        # Every logical sub-matrix has to split evenly across the TP ranks.
        assert all(size % self.tp_size == 0 for size in output_sizes)

        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
598
599
600
601
602

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one checkpoint tensor into this merged layer's fused ``param``.

        ``loaded_shard_id`` indexes ``self.output_sizes`` and selects the
        logical sub-matrix that ``loaded_weight`` belongs to. When it is
        ``None``, the tensor is treated as already fused on disk. It is then
        split into per-shard pieces, and each piece is loaded through a
        recursive call to this method with its shard index.

        Args:
            param: Destination parameter (fused weight, bias or scale).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index of the logical shard, or ``None`` for a
                tensor that is fused on disk.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            # Record the quantization type per shard (the same value for all
            # shards when the tensor is fused on disk), then stop.
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            # For a known shard, this rank's slice is appended to
            # param.data_container (and indexed via shard_id_map) instead of
            # being copied now. Presumably it is materialized later; TODO
            # confirm. Without a shard id, fall through to the generic path.
            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # (shard_id, offset, size) of each logical shard in the full,
            # unsharded on-disk tensor.
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    # Rescale unquantized (offset, size) pairs to the
                    # quantized param length.
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                # Recurse with an explicit shard id so the per-shard path
                # below handles TP slicing.
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            # Offset/size of this shard inside the per-rank fused param:
            # global sizes divided by tp_size.
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # NOTE(review): offset = size * shard_id assumes every shard
                # has the same size along output_dim; confirm for layers
                # whose output_sizes differ.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # Destination: this shard's region of the fused param.
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Source: this rank's slice of the full checkpoint tensor.
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            # No output_dim: the whole param is overwritten by every shard.
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

738
739
740
741
742
743
744
745
746
747
748
749
750
    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter. Its ``output_dim`` attribute is the
                axis of ``loaded_weight`` that gets split into shards.
            loaded_weight: The fused checkpoint tensor holding every
                sub-matrix listed in ``self.output_sizes`` back to back.
        """
        # (shard_id, shard_offset, shard_size) for each sub-matrix, laid out
        # contiguously along the output dim of the unsharded fused tensor.
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            # Re-enter the per-shard loader with an explicit shard id.
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load one sub-matrix (or an already-fused tensor) into ``param``.

        Args:
            param: Destination vLLM parameter; the copy itself is delegated to
                its ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded, or ``None`` when the checkpoint
                stores all sub-matrices already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact-type check (not isinstance): subclasses of these
                # parameter types fall through to the fused-split path below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Element offsets/sizes are converted to ceil(x / block_n) block
            # counts before being divided across tensor-parallel ranks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)
813

814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: every rank keeps a single
            # KV head, and each KV head is replicated on
            # tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _qkv_layout(self, num_heads: int,
                    num_kv_heads: int) -> dict[str, tuple[int, int]]:
        """Return ``{"q"|"k"|"v": (offset, size), "total": (size, 0)}``.

        Offsets and sizes are along the output dim of a q/k/v-concatenated
        tensor holding ``num_heads`` query heads and ``num_kv_heads`` key and
        value heads each. Pass the total head counts for the unsharded
        checkpoint layout, or the per-rank counts for this rank's layout.
        """
        q_size = num_heads * self.head_size
        kv_size = num_kv_heads * self.head_size
        return {
            "q": (0, q_size),
            "k": (q_size, kv_size),
            "v": (q_size + kv_size, kv_size),
            "total": (q_size + 2 * kv_size, 0),
        }

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v along the
        output dim of an unsharded, already-fused QKV checkpoint tensor."""
        layout = self._qkv_layout(self.total_num_heads,
                                  self.total_num_kv_heads)
        return [(shard_id, *layout[shard_id]) for shard_id in ("q", "k", "v")]

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a q/k/v shard (or an already-fused QKV tensor) into ``param``.

        Args:
            param: Destination vLLM parameter; the copy itself is delegated to
                its ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores q/k/v already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact-type check (not isinstance): subclasses fall through
                # to the fused-split path below.
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            # Same guard as MergedColumnParallelLinear.weight_loader_v2: fail
            # with a clear assertion rather than a TypeError on None[0].
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Convert element offsets/sizes to ceil(x / block_n) block counts.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Legacy loader for a q/k/v shard or an already-fused QKV tensor.

        Args:
            param: Destination parameter. Optional attributes such as
                ``output_dim``, ``packed_dim``, ``use_bitsandbytes_4bit`` and
                the GGUF flags select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores q/k/v already fused; the fused tensor is
                then split and this method is re-entered once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Unsharded q/k/v layout; it does not depend on the shard, so it
            # is built once rather than on every loop iteration.
            orig_qkv_offsets = self._qkv_layout(self.total_num_heads,
                                                self.total_num_kv_heads)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in (
                    self._get_fused_shard_offsets()):
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Layout of this rank's slice of the fused q/k/v parameter.
            local_qkv_offsets = self._qkv_layout(self.num_heads,
                                                 self.num_kv_heads)
            shard_offset, shard_size = local_qkv_offsets[loaded_shard_id]

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, local_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                # KV heads may be replicated: num_kv_head_replicas consecutive
                # ranks read the same KV slice of the checkpoint.
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1160
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods named in WEIGHT_LOADER_V2_SUPPORTED use the
            # BasevLLMParameter-based loader; all others use the legacy one.
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``input_dim`` attribute and the checkpoint is
        not already sharded, ``loaded_weight`` is narrowed along that dim to
        the ``tp_rank``-th chunk whose length equals ``param``'s extent there.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Test against None rather than truthiness: input_dim == 0 must
            # also be split, otherwise the narrow below (which uses the
            # materialized extent) reads past the end on tp_rank > 0.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via ``param.load_row_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Compute X_i A_i on this rank, all-reducing when configured.

        Returns the output alone when ``return_bias`` is false, otherwise
        ``(output, bias)`` where ``bias`` is only non-None when
        ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1340
1341


1342
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Queries are projected from the decoder hidden states by an internal
    `ColumnParallelLinear`. Keys and values are projected from the encoder
    hidden states by an internal `QKVParallelLinear` built with zero query
    heads. Both sub-layers are kept in a plain dict (`self.proj`), so their
    parameters are not registered on this module. This module only owns
    empty placeholder parameters, whose weight loaders forward each loaded
    tensor to the matching sub-layer parameter.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it. Note that
                       `forward` drops the second value returned by each
                       sub-layer (presumably that bias).
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Resolve the documented default here. The kv sub-layer below is
        # built with total_num_heads=0, so leaving the "None means same as
        # query heads" resolution to it would yield 0 kv heads.
        # NOTE(review): confirm against QKVParallelLinear.__init__.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module; their weight
        # loader routes each loaded tensor into the real sub-layer parameter.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj: dict[str, LinearBase] = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Placeholder bias; loaded values are routed by weight_loader_v1.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on both sub-layers."""
        # The sub-layers live in a plain dict rather than as registered
        # submodules, so they are processed explicitly here.
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _get_synced_proj(
        self,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ) -> LinearBase:
        """Return `self.proj[mode]` after copying placeholder attributes.

        For every placeholder parameter on this module that has a same-named
        parameter on the sub-layer, attributes missing on the sub-layer's
        parameter are copied over via `sync_weight_attrs`.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Query sub-layer, with placeholder attributes synced.

        Accessed on every `forward` call, so the sync runs each time.
        """
        return self._get_synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Key/value sub-layer, with placeholder attributes synced.

        Accessed on every `forward` call that has encoder states, so the
        sync runs each time.
        """
        return self._get_synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on `src_param` but absent on `tgt_param`.

        Attributes that `tgt_param` already has are left untouched. For
        bitsandbytes 4-bit parameters, the missing attributes are split into
        q and kv parts by `left_shift_bitsandbytes_4bit_shard`, and only the
        part matching `mode` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two parameters match when they have the same type and equal instance
        attribute dicts, ignoring the weight loader attributes.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Asserts that exactly one parameter of `layer` matches `param`
        according to `_is_same_param`.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to queries and encoder states to keys/values.

        Returns:
            `(q, k, v)`. `k` and `v` are None when `encoder_hidden_states` is
            None (encoder KV already cached).
        """
        # NOTE(review): the second value returned by each sub-layer
        # (presumably its un-added bias when skip_bias_add=True) is dropped.
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split into chunks of kv_size along the last dim: (k, v).
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _resolve_load_target(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ) -> tuple[LinearBase, torch.nn.Parameter, tuple[Optional[str], ...]]:
        """Map a placeholder param to (sub-layer, real param, shard args).

        Shard "q" goes to the query sub-layer, whose loader is called without
        a shard id. Any other id (including None) goes to the kv sub-layer,
        with the id passed through as the extra loader argument.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load via the sub-layer's v1 `weight_loader` (used for the bias)."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._resolve_load_target(
            param, loaded_shard_id)
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load into the matching sub-layer parameter.

        Uses the sub-layer's `weight_loader_v2` when the quant method's class
        name is listed in `WEIGHT_LOADER_V2_SUPPORTED`, otherwise its v1
        `weight_loader`.
        """
        layer, target_param, shard_id_args = self._resolve_load_target(
            param, loaded_shard_id)
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize the layer's configuration for `repr()`."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s