linear.py 60.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
import vllm.envs as envs
11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

logger = init_logger(__name__)

45
# Class names of the linear methods whose parameters are loaded through
# ``weight_loader_v2``. Layers compare ``quant_method.__class__.__name__``
# against this list to choose between ``weight_loader_v2`` and the legacy
# ``weight_loader``. It stays a mutable list because
# ``register_weight_loader_v2_supported_method`` appends to it.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
]
64

65

66
67
68
69
70
71
def register_weight_loader_v2_supported_method(cls):
    """Decorator to register a LinearMethod as supporting weight_loader_v2.

    The class name is added to ``WEIGHT_LOADER_V2_SUPPORTED``. Registration
    is idempotent: a name that is already listed is not appended again, so
    re-decorating (or re-importing) a class does not grow the list.

    Args:
        cls: The linear-method class to register.

    Returns:
        ``cls`` itself, unchanged, so the function can be used as a decorator.
    """
    method_name = cls.__name__
    if method_name not in WEIGHT_LOADER_V2_SUPPORTED:
        WEIGHT_LOADER_V2_SUPPORTED.append(method_name)
    return cls


72
73
74
75
76
77
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise the inputs unchanged.
    """
    tile: int | None = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


84
85
86
87
88
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
89
90
91
92
93
94
95
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


96
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Rescale a shard's offset and size to the BitsAndBytes-quantized layout.

    The unquantized offset and size of the requested shard are scaled by the
    ratio between the first dimension of ``param.data`` and the unquantized
    total stored under ``shard_offsets["total"]``.

    Args:
        param: Parameter whose ``data.shape[0]`` is the quantized total.
        shard_offsets: Maps shard id to ``(offset, size)``; the ``"total"``
            entry holds the unquantized total as its first element.
        loaded_shard_id: Key of the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)`` — note the order is size first.
    """
    unquantized_total, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]

    packed_total = param.data.shape[0]
    packed_offset = offset * packed_total // unquantized_total
    packed_size = size * packed_total // unquantized_total
    return packed_size, packed_offset


113
114
115
116
117
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
118
119
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
120
121
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

137
    return param_data[shard_id], loaded_weight
138
139


140
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Declares the two abstract hooks a linear method provides:
    ``create_weights`` to attach parameters to a layer and ``apply`` to run
    the layer's computation with them.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights of a linear layer and attach them to ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list holding
                the widths of Wq, Wk and Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes supplied by the
                calling layer (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights held by ``layer`` to the input tensor ``x``.

        Expects ``create_weights`` to have been called on ``layer`` first.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the unquantized ``weight`` parameter and register it.

        A single tensor of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` is
        allocated, registered on ``layer`` as ``"weight"``, and tagged with
        the remaining ``extra_weight_attrs``. ``extra_weight_attrs`` must
        contain ``"weight_loader"``; it is removed from the dict and handed
        to the parameter.
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")

        # All logical outputs are stacked along dim 0 of one tensor.
        stacked_output_size = sum(output_partition_sizes)
        storage = torch.empty(
            stacked_output_size,
            input_size_per_partition,
            dtype=params_dtype,
        )
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if not current_platform.is_cpu():
            return

        # Imported lazily, only when running on CPU.
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear layer output for ``x`` using ``layer.weight``.

        Uses the batch-invariant kernel when ``VLLM_BATCH_INVARIANT`` is set
        and the platform is CUDA-alike; otherwise the dispatched unquantized
        GEMM.
        """
        use_batch_invariant = (
            envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike()
        )
        if use_batch_invariant:
            return linear_batch_invariant(x, layer.weight, bias)

        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
228
229


230
class LinearBase(PluggableLayer):
    """Base linear layer.

    Stores the layer configuration, resolves the quantization method and
    records the tensor-parallel rank/size used by subclasses.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False

        # No quantization config means plain (unquantized) weights.
        quant_method: QuantizeMethodBase | None
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            quant_method = UnquantizedLinearMethod()
        self.quant_method: QuantizeMethodBase | None = quant_method

        self.return_bias = return_bias
        self.disable_tp = disable_tp

        # With TP disabled the layer behaves as the only rank of a group of 1.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters."""
        for param in self.parameters():
            if not isinstance(param, BasevLLMParameter):
                continue
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size
282
283


284
# --8<-- [start:replicated_linear]
285
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # (A subclass may have set ``output_sizes`` before calling us.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        Handles the GGUF special cases (recording the weight type and
        materializing uninitialized parameters) and gives shapeless
        checkpoint tensors a shape of ``(1,)`` before copying.

        Raises:
            AssertionError: If the parameter and the loaded weight differ in
                size.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The trailing space keeps the two message fragments from running
        # together when the adjacent f-strings are concatenated.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns:
            The output tensor alone when ``return_bias`` is false; otherwise
            ``(output, output_bias)`` where ``output_bias`` is the bias
            parameter if ``skip_bias_add`` is set (the bias was not added to
            the output) and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the summary string shown in the module's ``repr``."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

403

404
# --8<-- [start:column_parallel_linear]
405
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed for the partition sizes below, so they
        # are set here before the base-class constructor runs.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (never None), matching the dtype passed
            # to create_weights above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` for unaligned fused outputs.

        The flag is enabled only when the quantization config exposes a
        ``weight_block_size``, the layer has more than one output partition,
        and at least one partition size is not a multiple of the block's
        first dimension. Any other situation leaves the flag untouched.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        When the parameter declares an ``output_dim`` and the checkpoint
        tensor is not already sharded, the tensor is narrowed along that
        dimension to the slice owned by ``self.tp_rank`` before copying.

        Raises:
            AssertionError: If the (possibly narrowed) loaded weight does not
                match the parameter's shape.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's column-parallel loader."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``input_``.

        The per-rank output is all-gathered when ``gather_output`` is set and
        ``tp_size > 1``.

        Returns:
            The output tensor alone when ``return_bias`` is false; otherwise
            ``(output, output_bias)`` where ``output_bias`` is the bias
            parameter if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the summary string shown in the module's ``repr``."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
621
        quant_config: Quantization configuration.
622
623
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
624
        return_bias: If true, return bias together with outputs in forward pass.
625
626
        disable_tp: If true, none of the weight matrices will be sharded;
                    this layer is treated as a "Replicated" MergedLinear.
627
628
    """

629
630
631
632
633
634
635
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up a merged column-parallel layer.

        ``output_sizes`` is stored on the instance before the parent
        constructor runs, and the parent is given ``sum(output_sizes)`` as
        its ``output_size``. Every entry of ``output_sizes`` must be
        divisible by the tensor-parallel size (1 when ``disable_tp`` is set).
        """
        # Must be set before super().__init__(), which checks for it.
        self.output_sizes = output_sizes

        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        # Each logical matrix has to split evenly across the TP ranks.
        assert not any(size % self.tp_size for size in output_sizes)

        merged_output_size = sum(output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=merged_output_size,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
660

661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < len(self.output_sizes)):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        elif isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
                raise ValueError(
                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

688
689
690
691
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
692
        loaded_shard_id: tuple[int, ...] | int | None = None,
693
    ):
694
        self.validate_shard_id(loaded_shard_id)
695
696
697
698
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
699
700
701
702
703
704
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
705
        if is_gguf_weight_type:
706
707
708
709
710
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
711
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
712
                }
713
714
            return

715
716
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
717
718
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
719

720
            if loaded_shard_id is not None:
721
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
722
723
724
725
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
726

727
728
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
729
730
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
731

732
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
733
734
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
735
            if output_dim is None:
736
                if needs_scalar_to_array:
737
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
738
739
                        param_data, loaded_weight, 0
                    )
740

741
742
743
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
744
745
746
747
748
749

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
750
            current_shard_offset = 0
751
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
752
753
754
755
756
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
757
758
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
759
                    "for BNB quantization with TP yet."
760
                )
761
            shard_offsets: list[tuple[int, int, int]] = []
762
            for i, output_size in enumerate(output_sizes):
763
764
765
766
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
767
                # Special case for Quantization.
768
769
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
770
771
772
773
774
775
776
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

777
                if packed_dim == output_dim:
778
779
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
780
                    # Special case for Marlin.
781
                    shard_size, shard_offset = adjust_marlin_shard(
782
783
                        param, shard_size, shard_offset
                    )
784

785
                if use_bitsandbytes_4bit:
786
787
788
789
790
791
792
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
793
794
                        param, orig_offsets, str(shard_id)
                    )
795

796
                loaded_weight_shard = loaded_weight.narrow(
797
798
                    output_dim, shard_offset, shard_size
                )
799
800
801
802
803
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
804
805
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
806
807
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
808
809
810
811
812
813
814

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

815
            # Special case for quantization.
816
817
818
819
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
820
821
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
822
                # Special case for Marlin.
823
                shard_size, shard_offset = adjust_marlin_shard(
824
825
                    param, shard_size, shard_offset
                )
826

827
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
828
829
830
831
832
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

833
            if use_bitsandbytes_4bit:
834
835
836
837
838
839
840
841
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
842
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
843
            start_idx = self.tp_rank * shard_size
844
            if not is_sharded_weight:
845
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
846
847
848
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
849
850
                param_data, loaded_weight, loaded_shard_id
            )
851

852
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
853
854
855
856
857
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
858
859
                    "the same for all partitions."
                )
860

861
862
863
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

864
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
    ):
        """
        Load a checkpoint tensor whose merged projections are already fused.

        Some checkpoints store the MLP projections as one concatenated
        tensor, so no shard id comes with it. The tensor is cut into
        consecutive slices along `param.output_dim`, and every slice is
        handed to `weight_loader_v2` with its position in the slice list as
        the shard id.

        Args:
            param: Parameter being filled; its `output_dim` selects the axis
                that is sliced.
            loaded_weight: Fused tensor read from the checkpoint.
            output_sizes: Width of each consecutive slice. When omitted (or
                empty), `self.output_sizes` is used.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        slice_widths = output_sizes or self.output_sizes

        # NOTE(review): shard ids are positions within `slice_widths`, so they
        # always start at 0 -- confirm this is what callers that pass a subset
        # of the output sizes expect.
        next_start = 0
        for shard_id, width in enumerate(slice_widths):
            start, length = next_start, width
            next_start += width

            # Special case for quantization: when the parameter is packed
            # along its output dimension, the offset and size have to be
            # adjusted to account for the packing.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                length, start = param.adjust_shard_indexes_for_packing(
                    shard_size=length, shard_offset=start
                )

            fused_slice = loaded_weight.narrow(param.output_dim, start, length)
            self.weight_loader_v2(param, fused_slice, shard_id)

904
905
906
907
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """
        Load `loaded_weight` into `param` for this merged column layer.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: An integer index into `self.output_sizes` when
                the tensor holds a single projection. `None` or a tuple of
                indices when the tensor holds several projections that are
                already fused on disk.
        """
        self.validate_shard_id(loaded_shard_id)

        # A plain index: the tensor holds exactly one of the merged outputs.
        if loaded_shard_id is not None and not isinstance(loaded_shard_id, tuple):
            assert loaded_shard_id < len(self.output_sizes)

            # Offset and width of this output inside the current rank's
            # partition of the merged parameter.
            rank_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            rank_width = self.output_sizes[loaded_shard_id] // self.tp_size

            if isinstance(param, BlockQuantScaleParameter):
                block_size = getattr(self, "weight_block_size", None)
                rank_width, rank_offset = adjust_block_scale_shard(
                    block_size, rank_width, rank_offset
                )

            param.load_merged_column_weight(
                loaded_weight=loaded_weight,
                shard_id=loaded_shard_id,
                shard_offset=rank_offset,
                shard_size=rank_width,
                tp_rank=self.tp_rank,
            )
            return

        # From here on the tensor is fused on disk (no id, or a tuple of ids).
        if isinstance(param, PerTensorScaleParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
            return
        # Exact type match on purpose: subclasses of these two parameter
        # types fall through to the splitting logic below.
        if type(param) in (RowvLLMParameter, BasevLLMParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight)
            return

        # Widths of the selected outputs; None means "use all of them".
        if loaded_shard_id:
            selected_widths = [self.output_sizes[idx] for idx in loaded_shard_id]
        else:
            selected_widths = None

        if isinstance(param, BlockQuantScaleParameter):
            block_size = getattr(self, "weight_block_size", None)
            selected_widths = [
                adjust_block_scale_shard(block_size, width, 0)[0]
                for width in (selected_widths or self.output_sizes)
            ]

        # TODO: @dsikka - move to parameter.py
        self._load_fused_module_from_checkpoint(
            param, loaded_weight, output_sizes=selected_widths
        )
        return
955

956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, head_size is used.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = head_size if v_head_size is None else v_head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (
            total_num_heads if total_num_kv_heads is None else total_num_kv_heads
        )

        # Divide the weight matrix along the last dimension.
        tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # More KV heads than ranks: split the KV heads across the ranks.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # At least as many ranks as KV heads: one KV head per rank, each
            # head shared by `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)

        # Per-projection output widths, summed over all ranks.
        q_width = self.num_heads * self.head_size * tp_size  # q_proj
        k_width = self.num_kv_heads * self.head_size * tp_size  # k_proj
        v_width = self.num_kv_heads * self.v_head_size * tp_size  # v_proj
        self.output_sizes = [q_width, k_width, v_width]

        super().__init__(
            input_size=self.hidden_size,
            output_size=q_width + k_width + v_width,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Raise ValueError unless the shard id is None, 'q', 'k' or 'v'."""
        if loaded_shard_id is None:
            return
        if not isinstance(loaded_shard_id, str):
            raise ValueError("This line should not be reached")
        if loaded_shard_id not in ("q", "k", "v"):
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                f"got shard id {loaded_shard_id}."
            )

    def _qkv_layout(
        self, num_heads: int, num_kv_heads: int
    ) -> list[tuple[str, int, int]]:
        """Return `(shard_id, offset, size)` for q, k and v laid end to end.

        The head counts are passed in so the same arithmetic serves both the
        per-rank layout (`num_heads` / `num_kv_heads`) and the whole-checkpoint
        layout (`total_num_heads` / `total_num_kv_heads`).
        """
        q_width = num_heads * self.head_size
        k_width = num_kv_heads * self.head_size
        v_width = num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_width),
            ("k", q_width, k_width),
            ("v", q_width + k_width, v_width),
        ]

    @staticmethod
    def _bnb_offset_table(
        layout: list[tuple[str, int, int]],
    ) -> dict[str, tuple[int, int]]:
        """Build the `{shard_id: (offset, size)}` table, plus a "total" entry,
        that is passed to `adjust_bitsandbytes_4bit_shard`.

        A new dict is created on every call.
        """
        table = {name: (start, width) for name, start, width in layout}
        _, last_start, last_width = layout[-1]
        table["total"] = (last_start + last_width, 0)
        return table

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank offset of shard 'q', 'k' or 'v'.

        "total" yields the combined per-rank width of all three shards; any
        other key yields None.
        """
        layout = self._qkv_layout(self.num_heads, self.num_kv_heads)
        offsets = {name: start for name, start, _ in layout}
        _, v_start, v_width = layout[-1]
        offsets["total"] = v_start + v_width
        return offsets.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank width of shard 'q', 'k' or 'v' (else None)."""
        layout = self._qkv_layout(self.num_heads, self.num_kv_heads)
        widths = {name: width for name, _, width in layout}
        return widths.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Load a checkpoint tensor in which q, k and v are already fused.

        Such a tensor arrives without a shard id. It is cut into its q, k and
        v parts along `param.output_dim`, using the total (unsharded) head
        counts, and each part is passed to `weight_loader_v2` with the
        matching shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        checkpoint_layout = self._qkv_layout(
            self.total_num_heads, self.total_num_kv_heads
        )
        for shard_id, start, length in checkpoint_layout:
            # Special case for quantization: when the parameter is packed
            # along its output dimension, the offset and size have to be
            # adjusted to account for the packing.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                length, start = param.adjust_shard_indexes_for_packing(
                    shard_size=length, shard_offset=start
                )

            fused_slice = loaded_weight.narrow(param.output_dim, start, length)
            self.weight_loader_v2(param, fused_slice, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """
        Load `loaded_weight` into `param`.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: 'q', 'k' or 'v' for a single projection, or None
                when the checkpoint tensor is not split per projection.
        """
        self.validate_shard_id(loaded_shard_id)

        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
            # Exact type match on purpose: subclasses of these two parameter
            # types take the splitting path in the else branch.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        rank_offset = self._get_shard_offset_mapping(loaded_shard_id)
        rank_width = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            block_size = getattr(self, "weight_block_size", None)
            rank_width, rank_offset = adjust_block_scale_shard(
                block_size, rank_width, rank_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=rank_offset,
            shard_size=rank_width,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """
        Attribute-driven loader: copy `loaded_weight` into `param`.

        Behaviour is selected by optional attributes read from `param`
        (`is_gguf_weight_type`, `is_gguf_weight`, `output_dim`, `packed_dim`,
        `use_bitsandbytes_4bit`, `is_sharded_weight`, `needs_scalar_to_array`,
        `ignore_warning`).

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: 'q', 'k' or 'v' for a single projection. None
                when q, k and v are already fused on disk; in that case the
                tensor is split and this method calls itself once per shard.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        if getattr(param, "is_gguf_weight_type", False):
            slot_of = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is None:
                # No shard id: the same weight type is recorded for q, k and v.
                param.shard_weight_type = dict.fromkeys(slot_of, loaded_weight.item())
            else:
                param.data[slot_of[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if getattr(param, "is_gguf_weight", False):
            gguf_dim = getattr(param, "output_dim", None)
            rank_width = loaded_weight.size(gguf_dim) // self.tp_size
            if loaded_shard_id is not None:
                # Keep this rank's slice and append it to the parameter's
                # container; nothing is copied into `param.data` here.
                rank_slice = loaded_weight.narrow(
                    gguf_dim, self.tp_rank * rank_width, rank_width
                )
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(rank_slice)
                return
            # A GGUF weight without a shard id continues with the generic
            # logic below.

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Layout of q, k and v inside the unsharded checkpoint tensor.
            checkpoint_layout = self._qkv_layout(
                self.total_num_heads, self.total_num_kv_heads
            )
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)

            for shard_id, start, length in checkpoint_layout:
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    block_size = getattr(self, "weight_block_size", None)
                    length, start = adjust_block_scale_shard(block_size, length, start)

                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    length = length // param.packed_factor
                    start = start // param.packed_factor

                    # Special case for Marlin.
                    length, start = adjust_marlin_shard(param, length, start)

                if use_bitsandbytes_4bit:
                    # Computed from the original (unadjusted) layout; replaces
                    # whatever size/offset the steps above produced.
                    length, start = adjust_bitsandbytes_4bit_shard(
                        param, self._bnb_offset_table(checkpoint_layout), shard_id
                    )

                fused_slice = loaded_weight.narrow(output_dim, start, length)
                self.weight_loader(param, fused_slice, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Layout of q, k and v inside this rank's slice of the parameter.
            rank_layout = self._qkv_layout(self.num_heads, self.num_kv_heads)
            spans = {name: (start, width) for name, start, width in rank_layout}
            shard_offset, shard_size = spans[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if getattr(param, "packed_dim", None) == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = (
                getattr(param, "is_sharded_weight", False) or use_bitsandbytes_4bit
            )

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, self._bnb_offset_table(rank_layout), loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)

            if not is_sharded_weight:
                # Ranks that share a replicated KV head read the same slice,
                # so for k/v the rank is divided by the replica count.
                if loaded_shard_id == "q":
                    shard_rank = self.tp_rank
                else:
                    shard_rank = self.tp_rank // self.num_kv_head_replicas
                loaded_weight = loaded_weight.narrow(
                    output_dim, shard_rank * shard_size, shard_size
                )

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions."
            )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1361
# --8<-- [start:row_parallel_linear]
1362
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: If `reduce_results` is false while a bias would be added
            inside this layer (`bias` is true and `skip_bias_add` is false).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up partition sizes, create the weights and the optional bias."""
        # Divide the weight matrix along the first dimension.
        # With `disable_tp` the layer behaves as a single, unsharded partition.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # Only the input dimension is split; the output dimension stays whole.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # v2 loader; every other method falls back to the legacy loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            # The bias is not parallelized: it always has the full output size.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy `loaded_weight` into `param`, taking this rank's input shard.

        Behaviour is driven by optional attributes read from `param`:
        `input_dim` (dimension to shard along), `is_sharded_weight` and
        `use_bitsandbytes_4bit` (the checkpoint tensor is already this rank's
        shard, so no narrowing is done), and `is_gguf_weight` /
        `is_gguf_weight_type` (GGUF-specific handling).

        Arguments:
            param: Destination parameter; it is filled in place.
            loaded_weight: Tensor read from the checkpoint.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than relying on truthiness:
            # dimension 0 is a valid shard dimension and must be divided too,
            # otherwise the narrowing below would index past the tensor on
            # every rank other than 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # The local parameter already has the per-rank size, so it tells
            # us how wide a slice to take from the full checkpoint tensor.
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load `loaded_weight` by delegating to the parameter's own loader.

        Arguments:
            param: Destination parameter; its `load_row_parallel_weight`
                method performs the actual load.
            loaded_weight: Tensor read from the checkpoint.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to `input_`.

        Arguments:
            input_: Input tensor. When `input_is_parallel` is false it is
                split along its last dimension into `tp_size` pieces and
                only this rank's piece is used.

        Returns:
            The output tensor alone when `return_bias` is false; otherwise a
            tuple `(output, bias)` in which `bias` is the layer bias when
            `skip_bias_add` is true and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer-specific part of the module's printed form."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s