linear.py 55.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

# Module-level logger used by the layers in this file.
logger = init_logger(__name__)

40
# Class names of quantization methods whose weights are loaded through a
# layer's ``weight_loader_v2`` (which expects ``BasevLLMParameter`` instances).
# Layers compare ``quant_method.__class__.__name__`` against this list when
# creating weights; any other method falls back to the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
63

64

65
66
67
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for BitBLAS-tiled parameters.

    If ``param`` carries a non-None ``bitblas_tile_size`` attribute, both
    values are floor-divided by that tile size; otherwise they are returned
    unchanged.

    Args:
        param: The parameter the shard is being loaded into.
        shard_size: Shard size as computed by the caller.
        shard_offset: Shard offset as computed by the caller.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is not None:
        return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
    return shard_size, shard_offset


73
74
75
76
77
78
79
80
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset up by the Marlin tile size.

    Parameters without a ``marlin_tile_size`` attribute (or with it set to
    ``None``) are passed through untouched.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


81
82
83
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The logical ``(offset, size)`` of shard ``loaded_shard_id`` is rescaled by
    the ratio between ``param``'s first dimension and the logical total size,
    using floor division.

    Args:
        param: The destination parameter; only ``param.data.shape[0]`` is read.
        shard_offsets: Maps each shard id to its logical ``(offset, size)``.
            The special key ``"total"`` maps to ``(total_size, _)``.
        loaded_shard_id: Key in ``shard_offsets`` of the shard being loaded.

    Returns:
        ``(size, offset)`` of the shard in ``param``. Note the order is
        swapped relative to the ``(offset, size)`` input tuples.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


96
97
98
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for one shard.

    For fused modules (QKV and MLP) we have an array of length N that holds
    one scale for each "logical" matrix, so ``param`` is an array of length
    N. ``loaded_weight`` corresponds to one of the shards on disk. Here we
    index ``param`` by ``shard_id`` so the caller can copy the shard's scale
    into the right slot.

    Args:
        param: Fused scale array, indexed by shard position.
        loaded_weight: The shard's scale. Either 0-d (AutoFP8) or with a
            leading dimension of size 1 (compressed-tensors), which is
            dropped.
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight)``.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or not an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape;
    # compressed-tensors scales do have a shape.
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


119
120
121
122
123
124
125
126
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    Splits shard 0 off ``bnb_weight_attrs``. The left result keeps only
    shard 0. The right result keeps the remaining shards; their offsets are
    shifted to start at 0 and their quant-state keys are renumbered from 0.
    Only the ``bnb_shard_offsets`` and ``bnb_quant_state`` keys are carried
    over. ``bnb_shard_offsets`` must support elementwise subtraction of a
    scalar (e.g. a NumPy array), because the right offsets are computed as
    ``offsets[1:] - offsets[1]``.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    quant_states = bnb_weight_attrs["bnb_quant_state"]

    offset_l = shard_offsets[:2]
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: quant_states[0]}
    quant_state_r = {
        i - 1: quant_states[i] for i in range(1, len(shard_offsets) - 1)
    }

    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


155
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list that contains the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes for the created weights.
                The layers in this file pass ``weight_loader`` here.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        ``extra_weight_attrs`` must contain ``weight_loader``. It is popped
        and handed to the parameter; the remaining entries are attached to the
        weight via ``set_weight_attrs``.
        """
        # The weights are not quantized. The tensor holds exactly the local
        # partition described by the sizes passed in:
        # sum(output_partition_sizes) x input_size_per_partition elements of
        # params_dtype. Any tensor-parallel split is already reflected in the
        # partition sizes chosen by the calling layer.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher.

        NOTE(review): ``remove_weight=True`` suggests the dispatcher may drop
        or replace ``layer.weight``; confirm in dispatch_cpu_unquantized_gemm.
        """
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the GEMM chosen by ``dispatch_unquantized_gemm`` on ``x``."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
241
242


243
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization config. ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is 0 and ``tp_size`` is 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            # Typed as Optional; subclasses assert it is not None before use.
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
294
295


296
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Takes no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # (Set before LinearBase.__init__ so get_quant_method can see it.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Replicated: the per-rank sizes equal the full sizes.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer; see ``skip_bias_add`` and ``return_bias``."""
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

412

413
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class:
                       per-partition sizes come from ``self.output_sizes``,
                       which subclasses set before calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        output_sizes: list[int] | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_size is
        # needed here, before LinearBase.__init__ (which sets it again).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim`` attribute, the checkpoint tensor is
        narrowed to the ``tp_rank``-th block along that dim. The narrowing is
        skipped when the tensor is already sharded (``is_sharded_weight``) or
        comes from bitsandbytes 4-bit loading.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer, all-gathering the output if ``gather_output``."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
604
        quant_config: Quantization configure.
605
606
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
607
        return_bias: If true, return bias together with outputs in forward pass.
608
609
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
610
611
    """

612
613
614
615
616
617
618
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks
        # hasattr(self, "output_sizes") to compute per-shard partition sizes.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every logical output must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"output_sizes {output_sizes} must be divisible by tp_size {self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
643

644
645
646
647
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one logical shard, or the whole fused weight, into ``param``.

        Args:
            param: Destination parameter (this rank's fused partition).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                weight being loaded. ``None`` means the checkpoint stores the
                weights already fused (e.g. Phi-3's ``gate_up_proj``). In
                that case the fused tensor is split and this method calls
                itself once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            if use_bitsandbytes_4bit:
                # Logical (offset, size) of every shard plus the logical
                # total. Identical for every shard, so build it once.
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

784
785
786
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """Split an already-fused checkpoint tensor and load each sub-layer.

        Handles models whose merged (e.g. gate/up MLP) layers are stored fused
        on disk, so no shard id is supplied. The tensor is cut along
        ``param.output_dim`` into one slice per entry of ``self.output_sizes``
        (in that order), and each slice is passed to ``weight_loader_v2`` with
        its index as the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter; must expose ``output_dim``.
            loaded_weight: Full (unsharded) fused checkpoint tensor.
        """
        # (shard_id, offset, size) of every sub-layer in the fused checkpoint
        # tensor; sub-layers are laid out back to back in output_sizes order.
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for quantization: if the parameter is packed along
            # the output dimension, the offset and size must be adjusted to
            # account for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` using the v2 parameter API.

        Args:
            param: Destination parameter; the copy itself is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Checkpoint tensor for sub-layer ``loaded_shard_id``,
                or for the whole merged layer when ``loaded_shard_id`` is None.
            loaded_shard_id: Index into ``self.output_sizes`` of the sub-layer
                being loaded, or None when the checkpoint is already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            if type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are converted to units of block_n output rows
            # (ceil-divided) before splitting across TP ranks; presumably one
            # scale entry per block along the output dim.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
            ) // self.tp_size
            shard_size = (
                (self.output_sizes[loaded_shard_id] + block_n - 1)
                // block_n
                // self.tp_size
            )
        else:
            # Per-rank offset/size of this sub-layer inside the local partition.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Each rank holds a single KV head, and every KV head is
            # replicated on tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Built from per-rank head counts times tp_size, so replicated KV
        # heads are counted once per replica.
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        )
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the offset of ``loaded_shard_id`` in this rank's partition.

        Uses per-rank head counts. ``"total"`` maps to the size of the whole
        local q+k+v partition; unknown keys yield None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the size of ``loaded_shard_id`` in this rank's partition.

        Uses per-rank head counts; unknown keys yield None.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` of q/k/v in a fused checkpoint.

        Offsets refer to the full, unsharded checkpoint tensor, so the total
        (not per-rank) head counts are used.
        """
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """Split an already-fused QKV checkpoint tensor and load q, k and v.

        Handles models whose QKV layers are stored fused on disk, so no shard
        id is supplied. The tensor is cut along ``param.output_dim`` into the
        q/k/v slices, and each is passed to ``weight_loader_v2``.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for quantization: if the parameter is packed along
            # the output dimension, the offset and size must be adjusted to
            # account for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v checkpoint tensor via the v2 parameter API.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_qkv_weight`` method.
            loaded_weight: Checkpoint tensor for one of q/k/v, or the fused
                QKV tensor when ``loaded_shard_id`` is None.
            loaded_shard_id: ``"q"``, ``"k"``, ``"v"`` or None.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            if type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Convert to units of block_n output rows (ceil-divided).
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v (or fused QKV) checkpoint tensor into ``param``.

        Handles GGUF parameters, fused-on-disk QKV (``loaded_shard_id`` is
        None; the tensor is split and this method recurses per shard),
        packed/Marlin and bitsandbytes-4bit layouts, and per-tensor scales.

        Args:
            param: Destination parameter; its optional attributes
                (``output_dim``, ``packed_dim``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ...) select the loading path.
            loaded_weight: Checkpoint tensor.
            loaded_shard_id: ``"q"``, ``"k"``, ``"v"`` or None for fused.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            fused_shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in fused_shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # Original (unquantized) q/k/v layout plus its total size.
                    orig_qkv_offsets = {
                        sid: (offset, size)
                        for sid, offset, size in fused_shard_offsets
                    }
                    orig_qkv_offsets["total"] = (
                        (self.total_num_heads + 2 * self.total_num_kv_heads)
                        * self.head_size,
                        0,
                    )

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "total": (
                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Ranks that replicate the same KV head read the same slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not split: every rank produces a full-width
        # partial result.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``input_dim`` and is not already sharded, the
        checkpoint tensor is narrowed to this rank's block along that dim.

        Args:
            param: Destination parameter (optionally a GGUF
                ``UninitializedParameter`` that is materialized here).
            loaded_weight: Full checkpoint tensor (or a 0-d scale).
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness): input_dim == 0 is a valid dim
            # and must still be split, otherwise the narrow below would read
            # past the end of loaded_weight for tp_rank > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via ``param.load_row_parallel_weight``.

        A 0-d tensor (single scale) is reshaped to shape ``(1,)`` first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel linear layer.

        Returns ``output`` when ``return_bias`` is False, otherwise
        ``(output, output_bias)`` where ``output_bias`` is ``self.bias`` only
        if ``skip_bias_add`` is set (else None).
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize partition sizes and TP settings for ``repr(module)``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s