linear.py 67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
32
from vllm.utils import GiB_bytes
33
34
35

# Module-level logger; used below e.g. to report weight-allocation failures.
logger = init_logger(__name__)

36
# Class names of quantization linear methods whose parameters are loaded
# through a layer's ``weight_loader_v2`` (which delegates to the parameter's
# own ``load_*`` methods) instead of the legacy ``weight_loader``. Layers
# select the loader by checking ``quant_method.__class__.__name__`` against
# this list.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
58

59

60
61
62
63
64
65
66
67
68
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for BitBLAS-tiled parameters.

    If ``param`` has a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it. Otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


69
70
71
72
73
74
75
76
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for Marlin-tiled parameters.

    If ``param`` has a ``marlin_tile_size`` attribute, both values are
    multiplied by it. Otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


77
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` maps each logical shard id to its unquantized
    ``(offset, size)``. It also has a ``"total"`` entry whose first element
    is the total unquantized size. The quantized ``param`` has
    ``param.data.shape[0]`` rows in total, so the requested shard's offset
    and size are rescaled proportionally to that quantized total.

    Args:
        param: The (quantized) parameter being loaded into.
        shard_offsets: Unquantized ``(offset, size)`` per shard id, plus
            ``"total"``.
        loaded_shard_id: Key of the shard to rescale.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that the size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before the floor division so precision is lost only once.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select one slot of a fused per-shard scale array for loading.

    Fused modules (QKV and MLP) hold an array of length N with one scale
    per "logical" matrix. ``loaded_weight`` is the scale of a single shard
    on disk. This returns the slot of ``param`` addressed by ``shard_id``,
    together with ``loaded_weight`` reduced to that slot's shape, so the
    caller can copy one into the other.

    Args:
        param: Fused scale array indexed by shard position.
        loaded_weight: The shard's scale as read from disk. It is either
            0-d (e.g. AutoFP8) or has a leading dim of size 1
            (e.g. compressed-tensors).
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight)`` with the leading size-1 dim of
        ``loaded_weight`` removed if it had one.

    Raises:
        ValueError: If ``shard_id`` is neither a known QKV name nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        try:
            shard_id = qkv_idxs[shard_id]
        except KeyError:
            # Report unknown names the same way as unsupported id types.
            raise ValueError(f"Unknown Shard Id {shard_id}") from None
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape.
    # compressed-tensors scales do have a shape; drop its size-1 leading dim.
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the
    remaining shards.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]

    # First shard keeps its original [start, end) boundaries.
    head_offsets = offsets[:2]
    # Remaining boundaries are re-based so they start at zero.
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    # Quant-state keys of the remaining shards shift down by one.
    tail_states = {}
    for idx in range(1, len(offsets) - 1):
        tail_states[idx - 1] = states[idx]

    left = dict(bnb_shard_offsets=head_offsets, bnb_quant_state=head_states)
    right = dict(bnb_shard_offsets=tail_offsets, bnb_quant_state=tail_states)
    return left, right


152
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes forwarded by the layer
                (e.g. ``weight_loader``). Implementations are expected to
                attach them to the created parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a dense, unquantized ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, i.e. it
        holds only this rank's partition. ``input_size`` and ``output_size``
        (the full sizes across ranks) are not used here.

        Raises:
            RuntimeError: If allocating the weight raises
                ``torch.cuda.OutOfMemoryError``.
        """
        # The weights are not quantized. Their shape is the per-rank
        # partition, so the allocation is
        # sum(output_partition_sizes) * input_size_per_partition elements.
        try:
            weight = Parameter(torch.empty(sum(output_partition_sizes),
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e
        # Weight loaders read these attributes to find the input/output axes.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, hand the layer to the CPU-specific GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the layer's GEMM on ``x`` (plus optional ``bias``).

        ``dispatch_unquantized_gemm()`` returns the callable to use. It is
        invoked as ``(layer, x, weight, bias)``.
        """
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
231
232


233
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization config. When None the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is set to 0 and ``tp_size`` to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto its BasevLLMParameters."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
288
289


290
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The full (unpartitioned) weight lives on every rank.
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` without any sharding."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer.

        The bias is added inside the GEMM unless ``skip_bias_add`` is set, in
        which case it is returned separately (when ``return_bias`` is true).
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

387

388
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class
                       itself: per-output partition sizes come from a
                       ``self.output_sizes`` attribute that subclasses set
                       before calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. The TP rank/size
        # are needed here (before LinearBase.__init__ also sets them) to
        # compute the per-partition sizes.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor into this rank's column partition.

        Unless ``param`` is marked ``is_sharded_weight`` or
        ``use_bitsandbytes_4bit``, the full tensor is narrowed along
        ``param.output_dim`` to the slice owned by ``self.tp_rank``. GGUF
        parameters are materialized on first load, with the output dim
        divided by ``tp_size``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer.

        The output is all-gathered across ranks only when ``gather_output``
        is set and ``tp_size > 1``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
576
        quant_config: Quantization configure.
577
578
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
579
        return_bias: If true, return bias together with outputs in forward pass.
580
581
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
582
583
    """

584
585
586
587
588
589
590
591
592
593
594
595
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks for
        # this attribute to compute one partition size per packed output.
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes), (
                       f"Every entry of output_sizes={output_sizes} must be "
                       f"divisible by the tensor parallel size {self.tp_size}")

        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
James Fleming's avatar
James Fleming committed
621

622
623
624
625
626
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
627
628
629
630
631
632
633
634
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
635
636
            return

637
638
639
        if is_gguf_weight:

            output_dim = getattr(param, "output_dim", None)
640
641
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
642

643
644
645
646
647
648
649
            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
650

651
652
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
653
654
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
655

656
        if loaded_shard_id is None:
657
658
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
659
            if output_dim is None:
660
                if needs_scalar_to_array:
661
662
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)
663

664
665
666
667
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
668
669
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
670
            shard_offsets: list[tuple[int, int, int]] = []
671
672
673
674
675
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
676
                # Special case for Quantization.
677
678
679
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
680
681
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
682
                    # Special case for Marlin.
683
684
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)
685

686
687
688
                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

689
                if use_bitsandbytes_4bit:
690
691
692
693
694
695
696
697
698
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

699
700
701
702
703
704
705
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
706
707
708
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
709
            # Special case for quantization.
710
711
712
713
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
714
715
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
716
                # Special case for Marlin.
717
718
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
719
720
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)
721

722
723
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
724
725
726
727
728
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

729
            if use_bitsandbytes_4bit:
730
731
732
733
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

734
735
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
736
            start_idx = self.tp_rank * shard_size
737
            if not is_sharded_weight:
738
739
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
740
741
742
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
743
744
                param_data, loaded_weight, loaded_shard_id)

745
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
746
747
748
749
750
751
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")
752

753
754
755
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

756
757
758
759
760
    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        The fused tensor is cut along ``param.output_dim`` into consecutive
        slices sized by ``self.output_sizes``; slice ``i`` is forwarded to
        ``self.weight_loader_v2`` with shard id ``i``.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # (shard_id, shard_offset, shard_size), laid out back-to-back in the
        # order of ``self.output_sizes`` (unpacked units).
        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        # Only packed parameters whose packing runs along the output dim
        # need their offsets/sizes rescaled before slicing.
        is_packed_along_output = (isinstance(
            param, (PackedColumnParameter, PackedvLLMParameter))
                                  and param.packed_dim == param.output_dim)

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if is_packed_along_output:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load one merged-column shard (or a fused checkpoint) via the v2
        parameter API.

        Args:
            param: Destination parameter; its ``load_merged_column_weight``
                method performs the final copy.
            loaded_weight: Checkpoint tensor for the shard, or for the whole
                fused module when ``loaded_shard_id`` is None.
            loaded_shard_id: Index into ``self.output_sizes`` selecting the
                sub-matrix, or None when the checkpoint stores the module
                already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type check on purpose: subclasses of these fall
                # through to the fused-split path below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block scales hold one entry per ``block_n`` output rows, so
            # convert row offsets/sizes to (ceil-divided) block counts
            # before partitioning across TP ranks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)
829

830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank keeps a single
            # KV head, shared by ``num_kv_head_replicas`` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank output offset of shard ``q``/``k``/``v``
        (``total`` gives the per-rank fused width), or None if unknown."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank output size of shard ``q``/``k``/``v``, or
        None if unknown."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_unsharded_qkv_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v inside a fused
        on-disk QKV tensor, computed from the total (unpartitioned) head
        counts."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Only packed parameters whose packing runs along the output dim
        # need their offsets/sizes rescaled before slicing.
        is_packed_along_output = (isinstance(
            param, (PackedColumnParameter, PackedvLLMParameter))
                                  and param.packed_dim == param.output_dim)

        for shard_id, shard_offset, shard_size in (
                self._get_unsharded_qkv_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if is_packed_along_output:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a q/k/v shard (or a fused QKV checkpoint) via the v2
        parameter API.

        Args:
            param: Destination parameter; its ``load_qkv_weight`` method
                performs the final copy.
            loaded_weight: Checkpoint tensor for the shard, or for the whole
                fused module when ``loaded_shard_id`` is None.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or None when
                the checkpoint stores QKV already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type check on purpose: subclasses of these fall
                # through to the fused-split path below.
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Convert row offsets/sizes into (ceil-divided) block counts.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Legacy (v1) loader for q/k/v shards or a fused QKV checkpoint.

        Args:
            param: Destination parameter; attributes such as ``output_dim``,
                ``packed_dim``, ``use_bitsandbytes_4bit`` and the GGUF flags
                select the loading path.
            loaded_weight: Checkpoint tensor for the shard, or for the whole
                fused module when ``loaded_shard_id`` is None.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or None when
                the checkpoint stores QKV already fused.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected and materialized later rather
                # than copied into ``param.data`` here.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_unsharded_qkv_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Unsharded (offset, size) per shard plus the fused total,
                # as expected by adjust_bitsandbytes_4bit_shard.
                orig_qkv_offsets = {
                    shard_id: (offset, size)
                    for shard_id, offset, size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    (self.total_num_heads + 2 * self.total_num_kv_heads) *
                    self.head_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for each slice.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank location of this shard inside ``param``.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard_id: (self._get_shard_offset_mapping(shard_id),
                               self._get_shard_size_mapping(shard_id))
                    for shard_id in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are partitioned one-to-one with ranks; KV heads may
            # be replicated, so several ranks read the same KV slice.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1178
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Legacy (v1) loader: copy this rank's slice of ``loaded_weight``
        (taken along ``param.input_dim`` when present) into ``param``."""
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # ``input_dim`` may legitimately be 0, so test against None rather
            # than truthiness; otherwise dim 0 would never be partitioned and
            # the narrow() below would read past the end on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """v2 loader: delegate sharding to ``param.load_row_parallel_weight``
        after promoting 0-d scalars to 1-element tensors."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel matmul, optionally all-reducing across TP.

        Returns the output alone when ``return_bias`` is False; otherwise
        ``(output, bias)`` where ``bias`` is only non-None when
        ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1358
1359


1360
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    The query is projected from the decoder hidden states by an internal
    ``q_proj_decoder`` (ColumnParallelLinear), and the key/value pair is
    projected from the encoder hidden states by an internal
    ``kv_proj_encoder`` (QKVParallelLinear with zero query heads). The two
    sub-layers are kept in a plain dict so their parameters are not
    auto-registered; this module instead owns empty placeholder parameters
    so that it can be loaded as a drop-in replacement for a single
    ``QKVParallelLinear`` module.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Resolve the documented default here. The KV sub-projection below
        # is built with ``total_num_heads=0``, so deferring a ``None`` to it
        # could resolve to zero KV heads instead of ``total_num_heads``.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        for layer in self.proj.values():
            if self.quant_method is not None:
                self.quant_method.process_weights_after_loading(layer)

    def _get_synced_proj(
        self,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ) -> nn.Module:
        """Return the ``mode`` sub-layer after syncing parameter attributes.

        For each parameter registered on this module, if the sub-layer has a
        parameter with the same name, any attributes present on ours but
        missing on theirs are copied over (see ``sync_weight_attrs``). This
        runs on every access.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """The query projection applied to decoder hidden states."""
        return self._get_synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """The key/value projection applied to encoder hidden states."""
        return self._get_synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes that exist on ``src_param`` but not ``tgt_param``.

        For bitsandbytes 4-bit params, the missing attributes are first split
        by ``left_shift_bitsandbytes_4bit_shard`` into the q-part and the
        kv-part, and only the part matching ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two params match when they have the same type and equal instance
        attribute dicts, ignoring the weight loader entries.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Exactly one parameter of ``layer`` must match ``param`` according to
        ``_is_same_param``.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1, (
            f"expected exactly one matching param in {type(layer).__name__}, "
            f"found {len(target_param_list)}")
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Compute cross-attention q from the decoder and k/v from the encoder.

        Args:
            decoder_hidden_states: input to the query projection.
            encoder_hidden_states: input to the key/value projection, or None
                when the encoder KV is already cached.

        Returns:
            ``(q, k, v)``; ``k`` and ``v`` are None when
            ``encoder_hidden_states`` is None.
        """
        # NOTE(review): unpacking assumes the sub-layers return an
        # (output, bias) pair, i.e. they keep the default return_bias.
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _resolve_proj_target(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ) -> tuple[nn.Module, torch.nn.Parameter, tuple[Optional[str], ...]]:
        """Route a placeholder param to its sub-layer and matching param.

        Shard ``"q"`` goes to ``q_proj_decoder`` and its loader is called
        without a shard id; any other shard id (including None) goes to
        ``kv_proj_encoder`` with the shard id forwarded.

        Returns:
            ``(layer, target_param, shard_id_args)``.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` via the sub-layer's v1 ``weight_loader``."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._resolve_proj_target(
            param, loaded_shard_id)
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` into the matching sub-layer parameter.

        Uses the sub-layer's ``weight_loader_v2`` when this module's quant
        method class name is listed in ``WEIGHT_LOADER_V2_SUPPORTED``, and its
        ``weight_loader`` otherwise.
        """
        layer, target_param, shard_id_args = self._resolve_proj_target(
            param, loaded_shard_id)
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Return a one-line summary of this layer's configuration."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s