"vscode:/vscode.git/clone" did not exist on "7c3604fb68031da36567151a9bdfe69e04de44b8"
linear.py 61.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
import vllm.envs as envs
11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

logger = init_logger(__name__)

45
# Class names of the linear methods whose parameters are loaded through a
# layer's ``weight_loader_v2`` rather than the legacy ``weight_loader``.
# Membership is tested against ``quant_method.__class__.__name__``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
]
64

65

66
67
68
69
70
71
def register_weight_loader_v2_supported_method(cls):
    """Class decorator marking a LinearMethod as weight_loader_v2 capable.

    Appends ``cls.__name__`` to ``WEIGHT_LOADER_V2_SUPPORTED`` and returns
    ``cls`` unchanged so it can be used as ``@register_...`` on a class.
    """
    method_name = cls.__name__
    WEIGHT_LOADER_V2_SUPPORTED.append(method_name)
    return cls


72
73
74
75
76
77
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        ``(shard_size, shard_offset)``. Both values are multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise they are returned unchanged.
    """
    tile: int | None = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


84
85
86
87
88
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
89
90
91
92
93
94
95
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


96
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The original (unquantized) offset and size of the requested shard are
    rescaled by ``param.data.shape[0] / total`` using integer arithmetic.

    Args:
        param: Parameter whose ``data.shape[0]`` gives the quantized extent.
        shard_offsets: Maps each shard id to ``(offset, size)``; the entry
            ``"total"`` holds the full unquantized extent as its first item.
        loaded_shard_id: Key of the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)`` -- size first, then offset.
    """
    full_extent, _ = shard_offsets["total"]
    shard_start, shard_extent = shard_offsets[loaded_shard_id]
    packed_extent = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Multiply before the floor division to limit rounding loss.
        return value * packed_extent // full_extent

    return _rescale(shard_extent), _rescale(shard_start)


113
114
115
116
117
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
118
119
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
120
121
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

137
    return param_data[shard_id], loaded_weight
138
139


140
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Subclasses implement ``create_weights`` to attach parameters to a layer
    and ``apply`` to run the linear transform with those parameters.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights of a linear layer.

        The created weights are set as attributes of ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list holding
                the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes supplied by the calling
                layer (the layers in this file pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights held by ``layer`` to the input tensor ``x``.

        ``create_weights`` is expected to have been called on ``layer``
        beforehand.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a single unquantized ``weight`` parameter on ``layer``.

        The parameter is an uninitialized tensor of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and dtype
        ``params_dtype``. ``extra_weight_attrs`` must contain
        ``weight_loader``; the remaining entries are attached to the
        parameter via ``set_weight_attrs``.
        """
        # The loader goes to the parameter constructor, so remove it from
        # the attributes that are attached afterwards.
        loader = extra_weight_attrs.pop("weight_loader")

        # All logical outputs of this partition are stacked along dim 0.
        rows = sum(output_partition_sizes)
        cols = input_size_per_partition
        weight = ModelWeightParameter(
            data=torch.empty(rows, cols, dtype=params_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand ``layer`` to the CPU GEMM dispatcher."""
        if not current_platform.is_cpu():
            return

        # Imported lazily, only when running on CPU.
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear transform of ``x`` with ``layer.weight``.

        Uses ``linear_batch_invariant`` when ``VLLM_BATCH_INVARIANT`` is set
        and the platform is CUDA-alike; otherwise uses the GEMM callable
        returned by ``dispatch_unquantized_gemm()``.
        """
        if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)

        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
228
229


230
class LinearBase(PluggableLayer):
    """Base linear layer.

    Stores the layer configuration, resolves the quantization method and
    records the tensor-parallel rank/size. Weights are created by subclasses.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration. When ``None`` an
            ``UnquantizedLinearMethod`` is used.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False

        # The quant method is resolved at this point, with only the
        # attributes above populated on ``self``.
        if quant_config is not None:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()

        self.return_bias = return_bias
        self.disable_tp = disable_tp

        # With TP disabled the layer behaves as the only rank of a group of 1.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are touched.
        """
        for param in self.parameters():
            if not isinstance(param, BasevLLMParameter):
                continue
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size
282
283


284
# --8<-- [start:replicated_linear]
285
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full weight is created with the complete ``input_size`` and output
    partition sizes, and loaded weights are copied in whole.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Takes no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # A subclass signals this by setting ``output_sizes`` before calling
        # this constructor.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without narrowing it.

        Args:
            param: Destination parameter. GGUF parameters are recognised via
                their ``is_gguf_weight`` / ``is_gguf_weight_type`` attributes.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the sizes of ``param`` and ``loaded_weight``
                differ after the adjustments below.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The two f-strings are concatenated, so the first needs a trailing
        # space to keep the message readable.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns:
            The output alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)`` where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set (the bias was not added) and ``None``
            if the bias was already applied.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary used in the module's repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

403

404
# --8<-- [start:column_parallel_linear]
405
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The TP rank/size are needed to size the partitions below, before
        # the base-class constructor runs (it assigns the same values again).
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)

        # If QKV or MergedColumn, use output size of each partition; such
        # subclasses set ``output_sizes`` before calling this constructor.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size)
                for logical_size in self.output_sizes
            ]
        else:
            self.output_partition_sizes = [self.output_size_per_partition]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()
        self.gather_output = gather_output

        assert self.quant_method is not None
        # Methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2 loader,
        # everything else the legacy one.
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            chosen_loader = self.weight_loader_v2
        else:
            chosen_loader = self.weight_loader
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=chosen_loader,
        )

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` for unaligned partitions.

        The flag is raised only when the quant config exposes a non-empty
        ``weight_block_size``, the layer has more than one output partition,
        and at least one partition size is not a multiple of the first block
        size entry. In every other case the method returns without changes.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if weight_block is None or len(weight_block) < 1:
            return
        if len(self.output_partition_sizes) <= 1:
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return
        if block_n <= 0:
            return

        if all(size % block_n == 0 for size in self.output_partition_sizes):
            return

        self.allow_fp8_block_shape_mismatch = True
        logger.debug(
            "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
            getattr(self, "prefix", "<unknown>"),
            block_n,
            self.output_partition_sizes,
        )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim`` and the weight is not already
        sharded, the slice ``[tp_rank * shard_size, +shard_size)`` along
        ``output_dim`` is taken, where ``shard_size`` is the parameter's
        extent in that dimension.
        """
        output_dim = getattr(param, "output_dim", None)

        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        pre_sharded = getattr(param, "is_sharded_weight", False)
        uses_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        pre_sharded = pre_sharded or uses_bnb_4bit

        # Special case for GGUF
        gguf_weight = getattr(param, "is_gguf_weight", False)
        gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if gguf_weight and isinstance(param, UninitializedParameter):
            target_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert target_shape[output_dim] % self.tp_size == 0
                target_shape[output_dim] = target_shape[output_dim] // self.tp_size
            param.materialize(target_shape, dtype=loaded_weight.dtype)

        destination = param.data
        if output_dim is not None and not pre_sharded:
            shard_size = destination.shape[output_dim]
            first = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, first, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert destination.shape == loaded_weight.shape
        destination.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a weight by delegating to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        is_shapeless = len(loaded_weight.shape) == 0
        if is_shapeless:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``input_``.

        Returns:
            The output alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)``, with ``output_bias`` being
            ``self.bias`` if ``skip_bias_add`` is set and ``None`` otherwise.
            The output is all-gathered only when ``gather_output`` is set and
            ``tp_size > 1``.
        """
        bias = None if self.skip_bias_add else self.bias

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        output = output_parallel
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)

        if self.return_bias:
            output_bias = self.bias if self.skip_bias_add else None
            return output, output_bias
        return output

    def extra_repr(self) -> str:
        """Return the layer summary used in the module's repr."""
        fields = (
            f"in_features={self.input_size}",
            f"output_features={self.output_size_per_partition}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"gather_output={self.gather_output}",
        )
        return ", ".join(fields)

602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
621
        quant_config: Quantization configure.
622
623
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
624
        return_bias: If true, return bias together with outputs in forward pass.
625
626
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
627
628
    """

629
630
631
632
633
634
635
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build the merged layer with ``sum(output_sizes)`` output features.

        ``output_sizes`` is stored on the instance before the parent
        constructor runs, and every entry must be divisible by the
        tensor-parallel size (checked with an ``assert``).
        """
        self.output_sizes = output_sizes

        # With TP disabled the layer acts as rank 0 of a group of size 1.
        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        assert all(size % self.tp_size == 0 for size in output_sizes)

        merged_output_size = sum(output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=merged_output_size,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
660

661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < len(self.output_sizes)):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        elif isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
                raise ValueError(
                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

688
689
690
691
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
692
        loaded_shard_id: tuple[int, ...] | int | None = None,
693
    ):
694
        self.validate_shard_id(loaded_shard_id)
695
696
697
698
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
699
700
701
702
703
704
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
705
        if is_gguf_weight_type:
706
707
708
709
710
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
711
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
712
                }
713
714
            return

715
716
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
717
718
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
719

720
            if loaded_shard_id is not None:
721
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
722
723
724
725
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
726

727
728
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
729
730
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
731

732
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
733
734
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
735
            if output_dim is None:
736
                if needs_scalar_to_array:
737
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
738
739
                        param_data, loaded_weight, 0
                    )
740

741
742
743
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
744
745
746
747
748
749

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
750
            current_shard_offset = 0
751
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
752
753
754
755
756
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
757
758
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
759
                    "for BNB quantization with TP yet."
760
                )
761
            shard_offsets: list[tuple[int, int, int]] = []
762
            for i, output_size in enumerate(output_sizes):
763
764
765
766
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
767
                # Special case for Quantization.
768
769
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
770
771
772
773
774
775
776
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

777
                if packed_dim == output_dim:
778
779
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
780
                    # Special case for Marlin.
781
                    shard_size, shard_offset = adjust_marlin_shard(
782
783
                        param, shard_size, shard_offset
                    )
784

785
                if use_bitsandbytes_4bit:
786
787
788
789
790
791
792
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
793
794
                        param, orig_offsets, str(shard_id)
                    )
795

796
                loaded_weight_shard = loaded_weight.narrow(
797
798
                    output_dim, shard_offset, shard_size
                )
799
800
801
802
803
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
804
805
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
806
807
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
808
809
810
811
812
813
814

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

815
            # Special case for quantization.
816
817
818
819
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
820
821
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
822
                # Special case for Marlin.
823
                shard_size, shard_offset = adjust_marlin_shard(
824
825
                    param, shard_size, shard_offset
                )
826

827
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
828
829
830
831
832
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

833
            if use_bitsandbytes_4bit:
834
835
836
837
838
839
840
841
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
842
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
843
            start_idx = self.tp_rank * shard_size
844
            if not is_sharded_weight:
845
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
846
847
848
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
849
850
                param_data, loaded_weight, loaded_shard_id
            )
851

852
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
853
854
855
856
857
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
858
859
                    "the same for all partitions."
                )
860

861
862
863
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

864
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
    ):
        """
        Load a tensor whose MLP shards are already concatenated on disk.

        Such a tensor arrives without a shard id, so it is cut along
        ``param.output_dim`` into consecutive slices, one per entry of
        ``output_sizes``, and every slice is handed to ``weight_loader_v2``
        together with its position in that list as the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Parameter that receives the slices.
            loaded_weight: Fused checkpoint tensor.
            output_sizes: Widths of the consecutive slices. ``None`` (or an
                empty list) falls back to ``self.output_sizes``.
        """
        slice_widths = output_sizes or self.output_sizes

        # NOTE(review): the shard id passed on below is the position inside
        # `slice_widths`, not necessarily the caller's original shard index;
        # confirm this is intended when a subset of shards is loaded.
        running_offset = 0
        for position, width in enumerate(slice_widths):
            slice_offset, slice_size = running_offset, width
            running_offset += width

            # Quantized parameters packed along the output dimension store
            # several values per element, so offset and size are rescaled by
            # the parameter itself.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                slice_size, slice_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=slice_size, shard_offset=slice_offset
                )

            weight_slice = loaded_weight.narrow(
                param.output_dim, slice_offset, slice_size
            )
            self.weight_loader_v2(param, weight_slice, position)

904
905
906
907
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """
        Load ``loaded_weight`` into ``param`` for a merged column layer.

        Args:
            param: Parameter to fill.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: An ``int`` selects a single logical shard.
                ``None`` or a tuple of ints means the checkpoint tensor
                already holds several shards fused together.
        """
        self.validate_shard_id(loaded_shard_id)
        shard_is_tuple = isinstance(loaded_shard_id, tuple)
        is_block_scale = isinstance(param, BlockQuantScaleParameter)

        if loaded_shard_id is not None and not shard_is_tuple:
            # Single-shard path: locate the shard inside the merged output
            # and scale offset/size down by the tensor-parallel world size.
            assert loaded_shard_id < len(self.output_sizes)

            offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            size = self.output_sizes[loaded_shard_id] // self.tp_size

            if is_block_scale:
                block_size = getattr(self, "weight_block_size", None)
                size, offset = adjust_block_scale_shard(block_size, size, offset)

            param.load_merged_column_weight(
                loaded_weight=loaded_weight,
                shard_id=loaded_shard_id,
                shard_offset=offset,
                shard_size=size,
                tp_rank=self.tp_rank,
            )
            return

        # From here on the checkpoint tensor is fused (None or tuple id).
        if isinstance(param, PerTensorScaleParameter):
            # With an explicit tuple, write the scale into the listed slots.
            # When weights are already fused on disk (e.g. Phi-3's
            # gate_up_proj), there is only a single scale for the entire
            # fused matrix: fill every slot with it so that any subsequent
            # reduction (like .max()) works while the parameter shape is
            # preserved.
            slots = loaded_shard_id if shard_is_tuple else range(param.data.shape[0])
            for slot in slots:
                param.load_merged_column_weight(
                    loaded_weight=loaded_weight, shard_id=slot
                )
            return

        # Exact type match on purpose: subclasses take the splitting path.
        if type(param) in (RowvLLMParameter, BasevLLMParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight)
            return

        fused_sizes = None
        if loaded_shard_id:
            fused_sizes = [self.output_sizes[idx] for idx in loaded_shard_id]

        if is_block_scale:
            # Block scales: convert each width via adjust_block_scale_shard.
            block_size = getattr(self, "weight_block_size", None)
            fused_sizes = [
                adjust_block_scale_shard(block_size, width, 0)[0]
                for width in (fused_sizes or self.output_sizes)
            ]

        # TODO: @dsikka - move to parameter.py
        self._load_fused_module_from_checkpoint(
            param, loaded_weight, output_sizes=fused_sizes
        )
969

970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
992
        quant_config: Quantization configuration.
993
994
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
995
        return_bias: If true, return bias together with outputs in forward pass.
996
        disable_tp: If true, weights matrix won't be sharded through tp rank.
997
998
    """

999
1000
1001
1002
1003
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        """Compute the per-rank head layout and initialise the fused layer.

        See the class docstring for the shared arguments. In addition,
        ``v_head_size`` is the size of each value head; it defaults to
        ``head_size`` when ``None``.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = head_size if v_head_size is None else v_head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (
            total_num_heads if total_num_kv_heads is None else total_num_kv_heads
        )

        # Divide the weight matrix along the last dimension.
        if disable_tp:
            tp_size = 1
        else:
            tp_size = get_tensor_model_parallel_world_size()

        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # Enough KV heads to split them across ranks.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # At least as many ranks as KV heads: one KV head per rank,
            # each head shared by several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)

        # Widths are per-rank sizes multiplied back by tp_size.
        q_width = self.num_heads * self.head_size * tp_size
        k_width = self.num_kv_heads * self.head_size * tp_size
        v_width = self.num_kv_heads * self.v_head_size * tp_size
        self.output_sizes = [
            q_width,  # q_proj
            k_width,  # k_proj
            v_width,  # v_proj
        ]

        super().__init__(
            input_size=self.hidden_size,
            output_size=q_width + k_width + v_width,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
1055

1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
    def validate_shard_id(self, loaded_shard_id: str | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

1068
1069
1070
1071
1072
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
1073
1074
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
1075
1076
1077
1078
1079
1080
1081
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
1082
            "v": self.num_kv_heads * self.v_head_size,
1083
1084
1085
        }
        return shard_size_mapping.get(loaded_shard_id)

1086
1087
1088
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Load a tensor whose Q, K and V projections are already fused on disk.

        Such a tensor arrives without a shard id, so it is cut along
        ``param.output_dim`` into the q, k and v slices (sized from the
        total, un-sharded head counts) and each slice is handed to
        ``weight_loader_v2`` with the matching shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # (shard_id, width) in on-disk order; offsets accumulate below.
        layout = (
            ("q", self.total_num_heads * self.head_size),
            ("k", self.total_num_kv_heads * self.head_size),
            ("v", self.total_num_kv_heads * self.v_head_size),
        )

        running_offset = 0
        for name, width in layout:
            slice_offset, slice_size = running_offset, width
            running_offset += width

            # Quantized parameters packed along the output dimension store
            # several values per element, so offset and size are rescaled by
            # the parameter itself.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                slice_size, slice_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=slice_size, shard_offset=slice_offset
                )

            weight_slice = loaded_weight.narrow(
                param.output_dim, slice_offset, slice_size
            )
            self.weight_loader_v2(param, weight_slice, name)

1130
1131
1132
1133
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """
        Load ``loaded_weight`` into ``param`` for a fused QKV layer.

        Args:
            param: Parameter to fill.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single
                projection, or ``None`` when the checkpoint tensor already
                holds all three fused together.
        """
        self.validate_shard_id(loaded_shard_id)

        if loaded_shard_id is not None:
            # Single-projection path.
            assert loaded_shard_id in ["q", "k", "v"]

            offset = self._get_shard_offset_mapping(loaded_shard_id)
            size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                block_size = getattr(self, "weight_block_size", None)
                size, offset = adjust_block_scale_shard(block_size, size, offset)

            param.load_qkv_weight(
                loaded_weight=loaded_weight,
                num_heads=self.num_kv_head_replicas,
                shard_id=loaded_shard_id,
                shard_offset=offset,
                shard_size=size,
                tp_rank=self.tp_rank,
            )
            return

        # Special case for certain models: q, k and v are fused on disk.
        if isinstance(param, PerTensorScaleParameter):
            # When weights are already fused on disk (e.g. Phi-3's
            # qkv_proj), there is only a single scale for the entire
            # fused matrix. Fill all slots (q, k, v) with this scale
            # to ensure that any subsequent reduction (like .max())
            # works correctly while preserving the parameter shape.
            for slot in range(param.data.shape[0]):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=slot, tp_rank=self.tp_rank
                )
            return

        # Exact type match on purpose: subclasses take the splitting path.
        if type(param) in (RowvLLMParameter, BasevLLMParameter):
            param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
            return

        # TODO: @dsikka - move to parameter.py
        self._load_fused_module_from_checkpoint(param, loaded_weight)
1175

1176
1177
1178
1179
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """
        Copy a q/k/v checkpoint tensor into the fused QKV parameter.

        Behaviour is driven by optional attributes read from ``param`` with
        ``getattr`` (``is_gguf_weight``, ``is_gguf_weight_type``,
        ``output_dim``, ``packed_dim``, ``needs_scalar_to_array``,
        ``use_bitsandbytes_4bit``, ``is_sharded_weight``, ``ignore_warning``).

        Args:
            param: Parameter to fill.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single
                projection, or ``None`` when q, k and v are already fused
                on disk; in that case the tensor is split and this method
                is called again once per projection.
        """
        self.validate_shard_id(loaded_shard_id)
        fused_on_disk = loaded_shard_id is None

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        gguf_weight = getattr(param, "is_gguf_weight", False)
        gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if gguf_weight_type:
            slot_of = {"q": 0, "k": 1, "v": 2}
            if fused_on_disk:
                # One type value applies to all three projections.
                param.shard_weight_type = {
                    name: loaded_weight.item() for name in slot_of
                }
            else:
                param.data[slot_of[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if not fused_on_disk:
                # Keep this rank's slice in the parameter's container.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if fused_on_disk:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Un-sharded layout of the fused tensor: name -> (offset, size).
            q_total = self.total_num_heads * self.head_size
            k_total = self.total_num_kv_heads * self.head_size
            v_total = self.total_num_kv_heads * self.v_head_size
            fused_layout = {
                "q": (0, q_total),
                "k": (q_total, k_total),
                "v": (q_total + k_total, v_total),
                "total": (q_total + k_total + v_total, 0),
            }
            use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)

            for name in ("q", "k", "v"):
                piece_offset, piece_size = fused_layout[name]

                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    block_size = getattr(self, "weight_block_size", None)
                    piece_size, piece_offset = adjust_block_scale_shard(
                        block_size, piece_size, piece_offset
                    )

                if packed_dim == output_dim:
                    piece_size //= param.packed_factor
                    piece_offset //= param.packed_factor

                    # Special case for Marlin.
                    piece_size, piece_offset = adjust_marlin_shard(
                        param, piece_size, piece_offset
                    )

                if use_bnb_4bit:
                    # Hand the helper its own mapping on every call.
                    piece_size, piece_offset = adjust_bitsandbytes_4bit_shard(
                        param, dict(fused_layout), name
                    )

                weight_piece = loaded_weight.narrow(
                    output_dim, piece_offset, piece_size
                )
                self.weight_loader(param, weight_piece, name)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size //= param.packed_factor
                shard_offset //= param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            already_sharded = getattr(param, "is_sharded_weight", False) or use_bnb_4bit

            if use_bnb_4bit:
                # Per-rank layout: name -> (offset, size).
                per_rank_layout = {
                    name: (
                        self._get_shard_offset_mapping(name),
                        self._get_shard_size_mapping(name),
                    )
                    for name in ("q", "k", "v")
                }
                per_rank_layout["total"] = (
                    self._get_shard_offset_mapping("total"),
                    0,
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, per_rank_layout, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)

            # k/v use a rank index divided by the KV replica count, so
            # ranks sharing a KV head read the same slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not already_sharded:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            if not getattr(param, "ignore_warning", False):
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1381
# --8<-- [start:row_parallel_linear]
1382
@PluggableLayer.register("row_parallel_linear")
1383
class RowParallelLinear(LinearBase):
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on the output and make Y
                       available to all GPUs; otherwise, every GPU keeps its
                       own partial output, which is Y = X_iA_i.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across TP ranks.
1415
    """

1416
1417
    # --8<-- [end:row_parallel_linear]

1418
1419
1420
1421
1422
1423
1424
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up the partition sizes, create the weights and the optional bias.

        Raises:
            ValueError: If ``reduce_results`` is False while a bias would be
                added inside the layer (``bias`` and not ``skip_bias_add``).
        """
        # Validate up front: the check only depends on the arguments, so a bad
        # configuration should fail before any weights are allocated.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        # With disable_tp the layer behaves as a single, unsharded partition.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # Only the input dimension is split; the output size is kept whole.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2
            # loader; everything else falls back to the legacy loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias is allocated at the full output size on every rank.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
1486
1487
1488

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``, narrowing to this rank's slice.

        Behaviour is driven by optional attributes read from ``param``:
        ``input_dim`` (axis to shard along), ``is_sharded_weight`` /
        ``use_bitsandbytes_4bit`` (the tensor is already the local shard, so no
        narrowing), and ``is_gguf_weight`` / ``is_gguf_weight_type``.

        Args:
            param: Destination parameter; written in place.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If, after narrowing, the shapes of the destination
                and the loaded tensor differ.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than truthiness: axis 0 is a valid
            # shard dimension and must be divided like any other.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # The destination already has the per-rank size along input_dim.
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"shape mismatch: param {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

1522
    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor by delegating to the parameter itself.

        Args:
            param: Destination parameter; its ``load_row_parallel_weight``
                method performs the actual load.
            loaded_weight: Tensor read from the checkpoint.
        """
        # Scales stored on disk frequently come without a shape (AutoFP8 is
        # one such case); promote them to a one-element 1-D tensor first.
        is_shapeless = loaded_weight.dim() == 0
        if is_shapeless:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1531
    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``input_`` and optionally all-reduce the result.

        Returns:
            The output alone when ``return_bias`` is False; otherwise a tuple
            ``(output, bias)`` where ``bias`` is ``self.bias`` if
            ``skip_bias_add`` is set and ``None`` if not.
        """
        if not self.input_is_parallel:
            # Split the full input along its last dim and keep this rank's piece.
            pieces = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            local_input = pieces[self.tp_rank].contiguous()
        else:
            local_input = input_

        # Matrix multiply.
        assert self.quant_method is not None
        # Fuse the bias into the GEMM on rank 0 only, so that it is not added
        # more than once when TP > 1; never fuse it when skip_bias_add is set.
        fuse_bias = not (self.tp_rank > 0 or self.skip_bias_add)
        bias_ = self.bias if fuse_bias else None
        partial_output = self.quant_method.apply(self, local_input, bias_)

        needs_all_reduce = self.reduce_results and self.tp_size > 1
        if needs_all_reduce:
            output = tensor_model_parallel_all_reduce(partial_output)
        else:
            output = partial_output

        if self.return_bias:
            deferred_bias = self.bias if self.skip_bias_add else None
            return output, deferred_bias
        return output
1559
1560

    def extra_repr(self) -> str:
        """Return a comma-separated summary of the layer's configuration."""
        parts = (
            f"in_features={self.input_size_per_partition}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"reduce_results={self.reduce_results}",
        )
        return ", ".join(parts)