linear.py 59.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
11
12
13
14
15
16
17
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
18
from vllm.logger import init_logger
19
from vllm.model_executor.custom_op import PluggableLayer
20
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

logger = init_logger(__name__)

45
# Class names of quantization methods for which column-parallel layers
# register ``weight_loader_v2`` instead of the legacy ``weight_loader``.
# Layers test ``quant_method.__class__.__name__`` for membership here.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
70
71
72
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size along the output dim.
        shard_offset: Shard offset along the output dim.

    Returns:
        ``(shard_size, shard_offset)``, both multiplied by
        ``param.marlin_tile_size`` when that attribute is present and not
        ``None``; otherwise returned unchanged.
    """
    tile: int | None = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


79
80
81
82
83
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
84
85
86
87
88
89
90
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


91
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized ``(offset, size)`` of ``loaded_shard_id`` is rescaled
    proportionally to the length of ``param.data`` along dim 0.

    Args:
        param: Quantized parameter; ``param.data.shape[0]`` is its length.
        shard_offsets: Maps shard id to ``(offset, size)`` in unquantized
            units. The ``"total"`` entry is ``(total_length, _)``.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(size, offset)`` in the units of ``param.data``'s first dim.
    """
    unquantized_total, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    packed_total = param.data.shape[0]
    # Multiply before the integer division so the ratio is applied exactly.
    scaled_size = size * packed_total // unquantized_total
    scaled_offset = offset * packed_total // unquantized_total
    return scaled_size, scaled_offset


108
109
110
111
112
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
113
114
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
115
116
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

132
    return param_data[shard_id], loaded_weight
133
134


135
class LinearMethodBase(QuantizeMethodBase):
    """Interface for linear-layer weight methods, quantized or not.

    Implementations allocate a layer's parameters in ``create_weights`` and
    perform the layer's matmul in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights of a linear layer and attach them to it.

        The created tensors become attributes of ``layer``.

        Args:
            layer: The layer that owns this method and receives the weights.
            input_size_per_partition: Size of the weight's input dim on this
                rank.
            output_partition_sizes: Size of the output dim of each logical
                weight on this rank. For a QKV layer, for example, this is a
                list holding the widths of Wq, Wk and Wv on this rank.
            input_size: Size of the weight's input dim across all ranks.
            output_size: Size of the weight's output dim across all ranks.
            params_dtype: Dtype of the created parameters.
            **extra_weight_attrs: Extra attributes for the weights, such as
                ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights held by ``layer`` (and optional ``bias``) to ``x``.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method that stores weights in ``params_dtype`` without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a dense ``weight`` covering this rank's partitions.

        All logical output partitions are stacked along dim 0, so the tensor
        has shape ``(sum(output_partition_sizes), input_size_per_partition)``.
        ``extra_weight_attrs`` must contain ``weight_loader``. The remaining
        entries are set as attributes on the parameter.
        """
        loader = extra_weight_attrs.pop("weight_loader")
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows, input_size_per_partition, dtype=params_dtype)
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if not current_platform.is_cpu():
            return
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ layer.weight.T (+ bias)``.

        Uses the batch-invariant kernel when batch invariance is enabled on a
        CUDA-like platform, and the dispatched GEMM otherwise.
        """
        if not (vllm_is_batch_invariant() and current_platform.is_cuda_alike()):
            gemm = dispatch_unquantized_gemm()
            return gemm(layer, x, layer.weight, bias)
        return linear_batch_invariant(x, layer.weight, bias)
223
224


225
class LinearBase(PluggableLayer):
    """Common base for vLLM linear layers.

    Args:
        input_size: Input dimension of the linear layer.
        output_size: Output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization configuration; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        # Column-parallel subclasses may set this when their output
        # partitions are not multiples of the FP8 weight block size.
        self.allow_fp8_block_shape_mismatch = False
        self.quant_method: QuantizeMethodBase | None = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer acts as a single rank-0 shard.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters."""
        vllm_params = (
            p for p in self.parameters() if isinstance(p, BasevLLMParameter)
        )
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
277
278


279
# --8<-- [start:replicated_linear]
280
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full, unsharded weight (and bias).

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on the weights of replicated linear layers,
                    which are never sharded.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Subclasses that set ``output_sizes`` before calling this constructor
        # (e.g. MergedReplicatedLinear) get one output partition per entry.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layers support a quant method.
        assert self.quant_method is not None
        # The full input/output sizes are passed as the per-partition sizes
        # because nothing is sharded.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full checkpoint tensor into ``param``.

        Handles the GGUF special cases and gives 0-d tensors shape ``(1,)``.
        For GGUF, weight-type entries are recorded on ``param.weight_type``,
        and uninitialized parameters are materialized with the loaded shape.

        Raises:
            AssertionError: if the sizes of ``param`` and ``loaded_weight``
                differ.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns the output alone when ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``, where ``bias`` is ``self.bias`` only when
        ``skip_bias_add`` left it for the caller to add, and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

398

399
# --8<-- [start:column_parallel_linear]
400
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across tp ranks.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. LinearBase also
        # sets tp_rank/tp_size, but they are needed here first to size the
        # partitions.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size)
                for logical_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype is the resolved dtype (never None), matching
            # ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` for misaligned partitions.

        The flag is set only when the quant config has a ``weight_block_size``
        whose first entry is a positive int-convertible ``block_n``, the layer
        has more than one output partition, and at least one partition size
        is not a multiple of ``block_n``.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        When ``param`` has an ``output_dim`` and the tensor on disk is not
        already sharded (``is_sharded_weight`` or BitsAndBytes 4-bit), the
        slice starting at ``tp_rank * shard_size`` along ``output_dim`` is
        copied. Here ``shard_size`` is ``param``'s own size along that dim.
        0-d tensors are given shape ``(1,)``.

        Raises:
            AssertionError: if a GGUF output dim is not divisible by
                ``tp_size``, or if the (sliced) tensor's shape differs from
                ``param``'s.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shard shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0, (
                    f"GGUF weight dim {final_shape[output_dim]} is not "
                    f"divisible by tp_size {self.tp_size}"
                )
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via ``param.load_column_parallel_weight``.

        0-d tensors are reshaped to ``(1,)`` before loading.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply this rank's column shard to ``input_``.

        If ``gather_output`` is set and ``tp_size > 1``, the per-rank outputs
        are all-gathered. Otherwise the local shard is returned. With
        ``return_bias`` the result is ``(output, bias)``, where ``bias`` is
        ``self.bias`` only when ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
616
        quant_config: Quantization configure.
617
618
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
619
        return_bias: If true, return bias together with outputs in forward pass.
620
621
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
622
623
    """

624
625
626
627
628
629
630
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a column-parallel layer whose output is ``sum(output_sizes)``.

        Each entry of ``output_sizes`` must be divisible by the TP world size
        (asserted), so each logical output can be sharded on its own.
        """
        # Set before the parent constructor runs: ColumnParallelLinear checks
        # for ``output_sizes`` to size one partition per logical output.
        self.output_sizes = output_sizes
        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        assert all(size % self.tp_size == 0 for size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
655

656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        """Check that ``loaded_shard_id`` addresses valid entries of ``output_sizes``.

        Accepted forms are ``None`` (the whole fused weight), a single index,
        or a tuple of consecutive indices.

        Raises:
            ValueError: if an index is outside ``[0, len(output_sizes))``, if
                a tuple's indices are not consecutive, or if
                ``loaded_shard_id`` is of an unsupported type.
        """
        if loaded_shard_id is None:
            return
        num_shards = len(self.output_sizes)
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < num_shards):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{num_shards - 1}. Got shard id {loaded_shard_id}."
                    )
            # Pairs of neighbours; empty for tuples of length 0 or 1.
            if any(b - a != 1 for a, b in zip(loaded_shard_id, loaded_shard_id[1:])):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        if isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= num_shards:
                raise ValueError(
                    f"Shard id should be between 0 and {num_shards - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        # Reachable for e.g. str or float shard ids: report what was passed.
        raise ValueError(
            "Shard id should be None, an int, or a tuple of ints. "
            f"Got {type(loaded_shard_id).__name__}: {loaded_shard_id!r}."
        )

683
684
685
686
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
687
        loaded_shard_id: tuple[int, ...] | int | None = None,
688
    ):
689
        self.validate_shard_id(loaded_shard_id)
690
691
692
693
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
694
695
696
697
698
699
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
700
        if is_gguf_weight_type:
701
702
703
704
705
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
706
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
707
                }
708
709
            return

710
711
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
712
713
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
714

715
            if loaded_shard_id is not None:
716
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
717
718
719
720
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
721

722
723
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
724
725
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
726

727
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
728
729
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
730
            if output_dim is None:
731
                if needs_scalar_to_array:
732
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
733
734
                        param_data, loaded_weight, 0
                    )
735

736
737
738
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
739
740
741
742
743
744

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
745
            current_shard_offset = 0
746
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
747
748
749
750
751
            if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
                    "for BNB quantization yet."
                )
752
            shard_offsets: list[tuple[int, int, int]] = []
753
            for i, output_size in enumerate(output_sizes):
754
755
756
757
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
758
                # Special case for Quantization.
759
760
761
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
762
763
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
764
                    # Special case for Marlin.
765
                    shard_size, shard_offset = adjust_marlin_shard(
766
767
                        param, shard_size, shard_offset
                    )
768

769
                if use_bitsandbytes_4bit:
770
771
772
773
774
775
776
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
777
778
                        param, orig_offsets, str(shard_id)
                    )
779

780
                loaded_weight_shard = loaded_weight.narrow(
781
782
                    output_dim, shard_offset, shard_size
                )
783
784
785
786
787
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
788
789
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
790
791
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
792
793
794
795
796
797
798

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

799
            # Special case for quantization.
800
801
802
803
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
804
805
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
806
                # Special case for Marlin.
807
                shard_size, shard_offset = adjust_marlin_shard(
808
809
                    param, shard_size, shard_offset
                )
810

811
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
812
813
814
815
816
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

817
            if use_bitsandbytes_4bit:
818
                shard_size = loaded_weight.shape[output_dim]
819
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
820

821
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
822
            start_idx = self.tp_rank * shard_size
823
            if not is_sharded_weight:
824
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
825
826
827
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
828
829
                param_data, loaded_weight, loaded_shard_id
            )
830

831
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
832
833
834
835
836
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
837
838
                    "the same for all partitions."
                )
839

840
841
842
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

843
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: list[int] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Checkpoint tensor whose shards are concatenated,
                in order, along ``param.output_dim``.
            output_sizes: Size of each shard contained in ``loaded_weight``.
                Falls back to ``self.output_sizes`` (every shard) when None
                or empty.
            shard_ids: Index into ``self.output_sizes`` of each entry of
                ``output_sizes``; this is the shard id forwarded to
                ``weight_loader_v2``. Defaults to ``0, 1, 2, ...``, which is
                only correct when ``loaded_weight`` begins with shard 0.

        Raises:
            ValueError: if ``shard_ids`` and ``output_sizes`` differ in length.
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = list(range(len(output_sizes)))
        if len(shard_ids) != len(output_sizes):
            raise ValueError(
                f"Got {len(shard_ids)} shard ids for {len(output_sizes)} "
                "output sizes when loading a fused checkpoint weight."
            )

        # Shards sit back to back in the fused tensor, so each offset is the
        # running sum of the preceding sizes (accumulate yields one extra
        # trailing total, which zip drops).
        shard_offsets = itertools.accumulate(output_sizes, initial=0)
        for shard_id, shard_offset, shard_size in zip(
            shard_ids, shard_offsets, output_sizes
        ):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` (vLLM-parameter code path).

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Checkpoint tensor to load.
            loaded_shard_id: ``None`` when the checkpoint tensor holds every
                shard fused together; a tuple of indices into
                ``self.output_sizes`` when it holds exactly those shards
                concatenated in order; an int index for a single shard.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            # NOTE(review): a per-tensor scale is always written to slot 0
            # here, even for a tuple shard id — confirm this is intended.
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            # Exact type match (not isinstance): subclasses of these fall
            # through to the shard-splitting path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return

            # Keep the real shard indices so that a tuple that does not start
            # at shard 0 (e.g. (1, 2)) is written into the right shard slots
            # instead of being re-numbered from 0.
            shard_ids = list(loaded_shard_id) if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                # Block-wise scales are split in block units, not weight rows.
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            if shard_ids is None:
                self._load_fused_module_from_checkpoint(
                    param, loaded_weight, output_sizes=output_sizes
                )
            else:
                self._load_fused_module_from_checkpoint(
                    param, loaded_weight, output_sizes=output_sizes, shard_ids=shard_ids
                )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset and size are divided by tp_size: every rank holds
        # 1/tp_size of each shard.
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each query and key head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank keeps a single
            # KV head, and num_kv_head_replicas ranks share each one.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Must be set before the base __init__ runs.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]
        output_size = sum(self.output_sizes)

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Raise ``ValueError`` unless ``loaded_shard_id`` is None or one of
        ``"q"``, ``"k"``, ``"v"``."""
        if loaded_shard_id is None:
            return
        if not isinstance(loaded_shard_id, str):
            raise ValueError(
                "Shard id for QKVParallelLinear should be a str or None, "
                f"got {type(loaded_shard_id).__name__}."
            )
        if loaded_shard_id not in ["q", "k", "v"]:
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                f"got shard id {loaded_shard_id}."
            )

    def _qkv_shard_layout(
        self, num_heads: int, num_kv_heads: int
    ) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for the q, k and v slices.

        The slices are contiguous along the output dimension in q, k, v
        order. Pass the total head counts for the full (checkpoint) layout,
        or the per-rank head counts for this rank's layout.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _bitsandbytes_qkv_offsets(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Build the ``{shard_id: (offset, size)}`` table (plus a ``"total"``
        entry of ``(total_size, 0)``) passed to
        ``adjust_bitsandbytes_4bit_shard``."""
        layout = self._qkv_shard_layout(num_heads, num_kv_heads)
        offsets = {shard_id: (offset, size) for shard_id, offset, size in layout}
        offsets["total"] = (sum(size for _, _, size in layout), 0)
        return offsets

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's offset of shard ``loaded_shard_id`` ("q", "k",
        "v", or "total" for the combined size), or None if unknown."""
        layout = self._qkv_shard_layout(self.num_heads, self.num_kv_heads)
        shard_offset_mapping = {shard_id: offset for shard_id, offset, _ in layout}
        shard_offset_mapping["total"] = sum(size for _, _, size in layout)
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's size of shard ``loaded_shard_id`` ("q", "k" or
        "v"), or None if unknown."""
        layout = self._qkv_shard_layout(self.num_heads, self.num_kv_heads)
        shard_size_mapping = {shard_id: size for shard_id, _, size in layout}
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # NOTE(review): offsets here are in weight rows; unlike the
        # MergedColumnParallelLinear v2 loader, they are not converted to
        # block units for BlockQuantScaleParameter — confirm fused
        # block-scale checkpoints never reach this path.
        shard_offsets = self._qkv_shard_layout(
            self.total_num_heads, self.total_num_kv_heads
        )

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` (vLLM-parameter code path).

        ``loaded_shard_id`` is "q", "k" or "v" for a single projection, or
        None when the checkpoint tensor holds q, k and v fused together.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            # Exact type match (not isinstance): subclasses of these fall
            # through to the shard-splitting path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` (legacy ``Parameter`` path).

        ``loaded_shard_id`` is "q", "k" or "v" for a single projection, or
        None when the checkpoint tensor holds q, k and v fused together; the
        fused case is split and re-dispatched per shard.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Split the fused tensor using the full (all-heads) layout.
            shard_offsets = self._qkv_shard_layout(
                self.total_num_heads, self.total_num_kv_heads
            )
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = self._bitsandbytes_qkv_offsets(
                        self.total_num_heads, self.total_num_kv_heads
                    )
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Location of this shard inside this rank's partition.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._bitsandbytes_qkv_offsets(
                    self.num_heads, self.num_kv_heads
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are partitioned across all ranks. KV heads may be
            # replicated, in which case num_kv_head_replicas consecutive
            # ranks read the same KV slice of the checkpoint tensor.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1333
# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if ``reduce_results`` is False while the layer itself
            would add the bias (``bias=True`` and ``skip_bias_add=False``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Reject the invalid argument combination before any work is done,
        # in particular before the quant method allocates weight tensors.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # parameter-class-driven loader; all others use the legacy loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias has the full output size on every rank. It only carries
            # `output_dim` (no `input_dim`), so `weight_loader` copies it
            # without narrowing.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        The loader reads ``input_dim``, ``use_bitsandbytes_4bit``,
        ``is_sharded_weight``, ``is_gguf_weight`` and
        ``is_gguf_weight_type`` from ``param``. When ``input_dim`` is set
        and the weight is not already sharded, ``loaded_weight`` is narrowed
        along ``input_dim`` to the slice starting at
        ``tp_rank * shard_size``.

        For GGUF, the weight-type tensor is stored as
        ``param.weight_type``. An ``UninitializedParameter`` is first
        materialized with the loaded shape, with ``input_dim`` divided by
        ``tp_size``. 0-dim tensors (e.g. scalar scales) are reshaped to
        ``(1,)``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Full (or pre-sharded) tensor read from the checkpoint.

        Raises:
            AssertionError: if the final shapes of ``param`` and the
                (narrowed) ``loaded_weight`` differ.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: 0 is a valid `input_dim` and
            # must still be divided across TP ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Shape mismatch: param {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's own row-parallel logic.

        0-dim single-element tensors are reshaped to ``(1,)`` first. The
        call is then delegated to ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute this rank's partial product and optionally all-reduce it.

        If ``input_is_parallel`` is False, the input is split along its last
        dimension into ``tp_size`` parts, and this rank keeps part
        ``tp_rank``.

        Returns:
            The output tensor when ``return_bias`` is False. Otherwise a pair
            ``(output, bias)``, where ``bias`` is ``self.bias`` if
            ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s