linear.py 56.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

# Module-level logger; used below for the FP8 block-shape-mismatch debug
# message and for weight-loading warnings.
logger = init_logger(__name__)

40
# Names of quantization-method classes (compared against
# ``quant_method.__class__.__name__``) for which linear layers register
# ``weight_loader_v2`` instead of the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset into BitBLAS tile units.

    If ``param`` carries a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the adjustment.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is None:
        return shard_size, shard_offset

    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size


75
76
77
78
79
80
81
82
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for Marlin-packed parameters.

    When ``param`` exposes a ``marlin_tile_size`` attribute, both values are
    multiplied by it; otherwise they pass through untouched.

    Returns:
        ``(shard_size, shard_offset)`` after the adjustment.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


83
84
85
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Args:
        param: the (quantized) parameter; only ``param.data.shape[0]`` is read.
        shard_offsets: maps each shard id to ``(offset, size)`` in the
            unquantized layout, plus a ``"total"`` entry whose first element
            is the total unquantized size.
        loaded_shard_id: key of the shard being loaded.

    Returns:
        ``(size, offset)`` rescaled proportionally from the unquantized total
        to ``param.data.shape[0]`` (floor division). The order matches the
        other ``adjust_*_shard`` helpers.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


98
99
100
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: the fused scale array, indexed by shard position.
        loaded_weight: the on-disk scale for one shard. It is either
            0-dim or has a leading dimension of length 1.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2) or an
            integer index into ``param``.

    Returns:
        ``(param[shard_id], loaded_weight)``, with a leading length-1
        dimension stripped from ``loaded_weight`` when present.

    Raises:
        ValueError: if ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: if ``shard_id`` is a string other than ``"q"``/``"k"``/``"v"``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


121
122
123
124
125
126
127
128
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    The first shard is split off from the rest. Offsets of the remaining
    shards are rebased so that they start at 0, and their quant-state keys
    are renumbered from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    offset_l = shard_offsets[:2]
    # Rebase the remaining boundaries so the second shard starts at 0; this
    # relies on array/tensor broadcasting for the scalar subtraction.
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    quant_state_r = {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    }
    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


157
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes supplied by the calling
                layer (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        ``extra_weight_attrs`` must contain ``weight_loader``. It is popped
        and handed to the parameter, and the remaining entries are attached
        to the weight via ``set_weight_attrs``.
        """
        # The weight is stored unquantized in ``params_dtype`` and allocated
        # at this rank's partition size:
        # (sum(output_partition_sizes), input_size_per_partition).
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear projection through the dispatched GEMM."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
243
244


245
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization config. When None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (tp_rank is forced to 0 and tp_size to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto each BasevLLMParameter it owns."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
297
298


299
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The full (unpartitioned) input size is passed for both the
        # per-partition and global input sizes because nothing is sharded.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        The sizes must match exactly. The checked exception is a 0-dim
        tensor, which is first reshaped to ``(1,)``.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        Returns the output alone if ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``, where bias is None unless
        ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

415

416
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Note: this constructor does
                       not read the argument. Per-partition sizes come from
                       ``self.output_sizes`` when a subclass sets it before
                       calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        output_sizes: list[int] | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed here, before LinearBase.__init__ (which
        # recomputes the same values), to size the partitions.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size) for output_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag fused layers whose partitions don't align to the weight block.

        This sets ``allow_fp8_block_shape_mismatch`` when the quant config
        has a positive integer ``weight_block_size[0]``, the layer has more
        than one output partition, and at least one partition size is not a
        multiple of that block size.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        Along ``param.output_dim``, rank ``tp_rank`` takes the contiguous
        slice ``[tp_rank * shard, (tp_rank + 1) * shard)``. There are two
        exceptions:

        - The slicing is skipped when the weight is already sharded (the
          ``is_sharded_weight`` or ``use_bitsandbytes_4bit`` attribute).
        - It is also skipped when the param has no output_dim.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        The output is all-gathered when ``gather_output`` is set and
        tp_size > 1. Returns the output alone if ``return_bias`` is false,
        otherwise ``(output, bias)``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
635
        quant_config: Quantization configure.
636
637
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
638
        return_bias: If true, return bias together with outputs in forward pass.
639
640
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
641
642
    """

643
644
645
646
647
648
649
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
650
651
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
652
653
654
        prefix: str = "",
        *,
        return_bias: bool = True,
655
        disable_tp: bool = False,
656
    ):
657
        self.output_sizes = output_sizes
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
674

675
676
677
678
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
679
        loaded_shard_id: int | None = None,
680
    ):
681
682
683
684
685
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
686
687
688
689
690
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
691
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
692
                }
693
694
            return

695
696
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
697
698
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
699

700
            if loaded_shard_id is not None:
701
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
702
703
704
705
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
706

707
708
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
709
710
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
711

712
        if loaded_shard_id is None:
713
714
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
715
            if output_dim is None:
716
                if needs_scalar_to_array:
717
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
718
719
                        param_data, loaded_weight, 0
                    )
720

721
722
723
724
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
725
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
726
            shard_offsets: list[tuple[int, int, int]] = []
727
728
729
730
731
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
732
                # Special case for Quantization.
733
734
735
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
736
737
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
738
                    # Special case for Marlin.
739
                    shard_size, shard_offset = adjust_marlin_shard(
740
741
                        param, shard_size, shard_offset
                    )
742

743
                shard_size, shard_offset = adjust_bitblas_shard(
744
745
                    param, shard_size, shard_offset
                )
746

747
                if use_bitsandbytes_4bit:
748
749
750
751
752
753
754
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
755
756
                        param, orig_offsets, str(shard_id)
                    )
757

758
                loaded_weight_shard = loaded_weight.narrow(
759
760
                    output_dim, shard_offset, shard_size
                )
761
762
763
764
765
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
766
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
767
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
768
            # Special case for quantization.
769
770
771
772
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
773
774
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
775
                # Special case for Marlin.
776
                shard_size, shard_offset = adjust_marlin_shard(
777
778
                    param, shard_size, shard_offset
                )
779
            shard_size, shard_offset = adjust_bitblas_shard(
780
781
                param, shard_size, shard_offset
            )
782

783
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
784
785
786
787
788
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

789
            if use_bitsandbytes_4bit:
790
                shard_size = loaded_weight.shape[output_dim]
791
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
792

793
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
794
            start_idx = self.tp_rank * shard_size
795
            if not is_sharded_weight:
796
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
797
798
799
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
800
801
                param_data, loaded_weight, loaded_shard_id
            )
802

803
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
804
805
806
807
808
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
809
810
                    "the same for all partitions."
                )
811

812
813
814
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

815
816
817
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter of this layer.
            loaded_weight: The fused checkpoint tensor. It is sliced along
                ``param.output_dim`` into one piece per entry of
                ``self.output_sizes``, and each piece is passed to
                ``weight_loader_v2`` with its index as the shard id.
        """
        # (shard_id, shard_offset, shard_size) for each sub-module; the
        # sub-modules are stored back-to-back along the output dimension.
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one sub-module shard (or the fused tensor) into a v2 parameter.

        Args:
            param: Destination parameter. The copy itself is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` naming the
                sub-module ``loaded_weight`` belongs to, or ``None`` when the
                checkpoint stores the sub-modules already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses of these types are not handled
                # here and fall through to the fused-splitting path below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Output channels that precede this sub-module, and its own width,
        # both measured over the full (unsharded) layer.
        preceding_size = sum(self.output_sizes[:loaded_shard_id])
        current_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block scales hold one entry per block_n output channels, so the
            # offset and size are ceil-divided by block_n before sharding.
            shard_offset = ((preceding_size + block_n - 1) // block_n) // self.tp_size
            shard_size = (current_size + block_n - 1) // block_n // self.tp_size
        else:
            shard_offset = preceding_size // self.tp_size
            shard_size = current_size // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # No more KV heads than ranks: each rank holds a single KV head,
            # and each KV head is shared by `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        # NOTE(review): set before super().__init__(), presumably because the
        # base class reads `output_sizes` while creating weights — confirm.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _qkv_shard_layout(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Return ``{shard_id: (offset, size)}`` for a fused q/k/v tensor.

        The q, k and v shards are laid out back-to-back along the output
        dimension in that order. The extra ``"total"`` entry is
        ``(total_size, 0)``, the same shape as the offset dicts this class
        hands to ``adjust_bitsandbytes_4bit_shard``.

        Args:
            num_heads: Number of query heads in the tensor being described.
            num_kv_heads: Number of key/value heads in that tensor.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's output offset for ``loaded_shard_id``.

        Accepts ``"q"``, ``"k"``, ``"v"`` or ``"total"``; returns None for any
        other id.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's output width for ``"q"``, ``"k"`` or ``"v"``.

        Returns None for any other id.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # The on-disk tensor covers all heads, not just this rank's.
        fused_layout = self._qkv_shard_layout(
            self.total_num_heads, self.total_num_kv_heads
        )
        for shard_id in ("q", "k", "v"):
            shard_offset, shard_size = fused_layout[shard_id]
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or the fused qkv tensor) into a v2 parameter.

        Args:
            param: Destination parameter. The copy itself is delegated to its
                ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores q/k/v already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses fall through to the fused path.
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block scales hold one entry per block_n output channels.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or the fused qkv tensor) into a v1 parameter.

        Args:
            param: Destination parameter. Optional attributes such as
                ``output_dim``, ``packed_dim``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight`` and the GGUF flags select the loading
                path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores q/k/v already fused. In the fused case the
                tensor is split and this method recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected in the param's data container
                # rather than copied into param.data here.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # The on-disk tensor covers all heads, not just this rank's.
            fused_layout = self._qkv_shard_layout(
                self.total_num_heads, self.total_num_kv_heads
            )
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id in ("q", "k", "v"):
                shard_offset, shard_size = fused_layout[shard_id]
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, fused_layout, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offsets within this rank's slice of the fused output.
            rank_layout = self._qkv_shard_layout(self.num_heads, self.num_kv_heads)
            shard_offset, shard_size = rank_layout[loaded_shard_id]
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, rank_layout, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are split across all ranks. KV heads may be
            # replicated, so `num_kv_head_replicas` consecutive ranks read the
            # same KV slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if ``reduce_results`` is False while a bias would be
            added inside the layer (``bias`` and not ``skip_bias_add``).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce each rank only holds a partial output, so a
        # bias fused into it here is rejected (see message).
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            # The bias spans the full output and is loaded unsharded (it has
            # an output_dim but no input_dim, so weight_loader never narrows).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``input_dim`` attribute and is not pre-sharded, the
        checkpoint tensor is narrowed along that dimension to this rank's
        slice. GGUF parameters are materialized first, with their shape
        divided by the TP size along ``input_dim``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than testing truthiness: dimension
            # 0 is a valid input_dim and must be sharded as well, matching
            # the narrow below.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via ``param.load_row_parallel_weight``.

        0-d tensors (single-element scales) are reshaped to shape ``(1,)``
        first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel matmul, all-reducing across ranks if enabled.

        Returns the output alone when ``return_bias`` is False, otherwise the
        pair ``(output, bias-or-None)``. The bias is returned only when
        ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s