linear.py 61.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
import vllm.envs as envs
11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

logger = init_logger(__name__)

45
# Class names of linear methods whose layers load weights through
# ``weight_loader_v2`` (which expects ``BasevLLMParameter`` instances).
# ColumnParallelLinear compares ``quant_method.__class__.__name__`` against
# this list to choose between ``weight_loader_v2`` and ``weight_loader``.
# More names can be added via ``register_weight_loader_v2_supported_method``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "HummingLinearMethod",
]


def register_weight_loader_v2_supported_method(cls):
    """Decorator to register a LinearMethod as supporting weight_loader_v2.

    Adds ``cls.__name__`` to ``WEIGHT_LOADER_V2_SUPPORTED`` and returns ``cls``
    unchanged. Registration is idempotent: a name that is already present is
    not appended again, so re-decorating a class (e.g. after a module reload)
    does not grow the list.
    """
    if cls.__name__ not in WEIGHT_LOADER_V2_SUPPORTED:
        WEIGHT_LOADER_V2_SUPPORTED.append(cls.__name__)
    return cls


def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the param's Marlin tile size.

    Args:
        param: Parameter being loaded. If it has a ``marlin_tile_size``
            attribute, both shard values are multiplied by it.
        shard_size: Shard size along the packed output dimension.
        shard_offset: Shard offset along the same dimension.

    Returns:
        ``(shard_size, shard_offset)``. The values are returned unchanged when
        ``param`` has no ``marlin_tile_size`` attribute.
    """
    marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
90
91
92
93
94
95
96
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


97
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Maps a shard's ``(offset, size)`` in the unquantized layout onto the first
    dimension of ``param.data``. Both values are scaled by
    ``param.data.shape[0] / total`` using integer floor division.

    Args:
        param: The (quantized) parameter being loaded.
        shard_offsets: Maps each shard id to ``(offset, size)`` in the
            unquantized layout. It must also contain a ``"total"`` entry whose
            first element is the total unquantized size.
        loaded_shard_id: Key of the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that the size comes
        first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
119
120
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
121
122
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

138
    return param_data[shard_id], loaded_weight
139
140


141
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Subclasses allocate a layer's weights in ``create_weights`` and run the
    forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes supplied by the calling
                layer for the created parameters (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: Layer whose weights were created by ``create_weights``.
            x: Input activations.
            bias: Optional bias to add to the output.

        Returns:
            The output tensor.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and dtype
        ``params_dtype``. Any partitioning was already applied by the calling
        layer when it computed these sizes. ``input_size`` and ``output_size``
        are not used here.

        ``extra_weight_attrs`` must contain ``weight_loader``, which is passed
        to the ``ModelWeightParameter``. All remaining entries are attached to
        the parameter via ``set_weight_attrs``.
        """
        # torch.empty leaves the storage uninitialized; the values are
        # expected to be written later through ``weight_loader``.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading hook. It only does work on CPU platforms.

        On CPU, the layer is handed to ``dispatch_cpu_unquantized_gemm`` with
        ``remove_weight=True``; see that helper for what it changes on the
        layer. Other platforms leave the layer untouched.
        """
        if current_platform.is_cpu():
            # Imported lazily, presumably because the helper only matters on
            # CPU.
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the unquantized projection of ``x`` with ``layer.weight``.

        When ``VLLM_BATCH_INVARIANT`` is enabled on a CUDA-alike platform,
        ``linear_batch_invariant`` is used. Otherwise the callable returned by
        ``dispatch_unquantized_gemm()`` is invoked.
        """
        if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, the layer has a bias (recorded as ``has_bias``).
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization config. When None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is 0 and ``tp_size`` is 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.has_bias = bias
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        # Subclasses may switch this on; for example, ColumnParallelLinear
        # does so in _maybe_allow_fp8_block_shape_mismatch.
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only ``BasevLLMParameter`` instances are updated; plain
        ``torch.nn.Parameter`` objects (such as a bias) are skipped.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size


# --8<-- [start:replicated_linear]
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Takes no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            bias,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Weights are not partitioned: the full input size is used for both
        # the per-partition and global input dims.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        Handles the GGUF special cases (recording the weight type and
        materializing an ``UninitializedParameter``) and 0-dim tensors before
        copying.

        Raises:
            AssertionError: If ``param`` and ``loaded_weight`` sizes differ.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns only the output when ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``: the bias is returned (instead of added)
        when ``skip_bias_add`` is set, and is ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


# --8<-- [start:column_parallel_linear]
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_rank/tp_size
        # are needed here, before LinearBase.__init__, because the partition
        # sizes below are derived from them.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            bias,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (LinearBase replaces None with the torch
            # default) so the bias matches the weight and ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Enable ``allow_fp8_block_shape_mismatch`` when block sizes don't fit.

        The flag is set only when the quant config exposes a
        ``weight_block_size``, the layer has more than one output partition,
        and some partition size is not a multiple of
        ``block_n = weight_block_size[0]``. Invalid or non-positive ``block_n``
        values leave the flag untouched.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's column shard of ``loaded_weight`` into ``param``.

        If ``param`` has an ``output_dim`` and the checkpoint tensor is not
        already sharded, the slice starting at ``tp_rank * shard_size`` of
        length ``shard_size`` (``param``'s size on that dim) is taken.

        Raises:
            AssertionError: If the (sliced) tensor's shape does not match
                ``param``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter at this rank's shard shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the column-parallel projection.

        When ``gather_output`` is set and ``tp_size > 1``, the per-rank outputs
        are all-gathered. The return convention for the bias matches
        ``ReplicatedLinear.forward``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
627
        quant_config: Quantization configure.
628
629
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
630
        return_bias: If true, return bias together with outputs in forward pass.
631
632
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
633
634
    """

635
636
637
638
639
640
641
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks for
        # ``output_sizes`` to compute the per-partition output sizes.
        self.output_sizes = output_sizes
        # Resolved early so divisibility can be validated before the parent
        # constructor runs.
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"Every entry of output_sizes={output_sizes} must be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        """Check that ``loaded_shard_id`` refers to valid ``output_sizes`` entries.

        The accepted forms are:

        * ``None``: always accepted (the tensor is already fused on disk).
        * int: must lie in ``[0, len(self.output_sizes))``.
        * tuple: must be non-empty, every index must be in range, and the
          indices must be consecutive and increasing.

        Raises:
            ValueError: If the shard id is out of range, is an empty or
                non-consecutive tuple, or has an unsupported type.
        """
        if loaded_shard_id is None:
            return
        num_shards = len(self.output_sizes)
        if isinstance(loaded_shard_id, tuple):
            # weight_loader slices with loaded_shard_id[0] / [-1], so an empty
            # tuple must be rejected here instead of failing with IndexError.
            if len(loaded_shard_id) == 0:
                raise ValueError("Shard id tuple must contain at least one index.")
            for idx in loaded_shard_id:
                if not (0 <= idx < num_shards):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{num_shards - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        if isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= num_shards:
                raise ValueError(
                    f"Shard id should be between 0 and {num_shards - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        # Reachable for any other type (e.g. a str or list shard id).
        raise ValueError(
            f"Unsupported shard id type {type(loaded_shard_id).__name__}: "
            f"{loaded_shard_id!r}."
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
698
        loaded_shard_id: tuple[int, ...] | int | None = None,
699
    ):
700
        self.validate_shard_id(loaded_shard_id)
701
702
703
704
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
705
706
707
708
709
710
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
711
        if is_gguf_weight_type:
712
713
714
715
716
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
717
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
718
                }
719
720
            return

721
722
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
723
724
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
725

726
            if loaded_shard_id is not None:
727
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
728
729
730
731
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
732

733
734
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
735
736
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
737

738
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
739
740
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
741
            if output_dim is None:
742
                if needs_scalar_to_array:
743
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
744
745
                        param_data, loaded_weight, 0
                    )
746

747
748
749
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
750
751
752
753
754
755

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
756
            current_shard_offset = 0
757
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
758
759
760
761
762
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
763
764
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
765
                    "for BNB quantization with TP yet."
766
                )
767
            shard_offsets: list[tuple[int, int, int]] = []
768
            for i, output_size in enumerate(output_sizes):
769
770
771
772
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
773
                # Special case for Quantization.
774
775
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
776
777
778
779
780
781
782
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

783
                if packed_dim == output_dim:
784
785
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
786
                    # Special case for Marlin.
787
                    shard_size, shard_offset = adjust_marlin_shard(
788
789
                        param, shard_size, shard_offset
                    )
790

791
                if use_bitsandbytes_4bit:
792
793
794
795
796
797
798
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
799
800
                        param, orig_offsets, str(shard_id)
                    )
801

802
                loaded_weight_shard = loaded_weight.narrow(
803
804
                    output_dim, shard_offset, shard_size
                )
805
806
807
808
809
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
810
811
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
812
813
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
814
815
816
817
818
819
820

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

821
            # Special case for quantization.
822
823
824
825
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
826
827
                shard_size = round(shard_size // param.packed_factor)
                shard_offset = round(shard_offset // param.packed_factor)
828
                # Special case for Marlin.
829
                shard_size, shard_offset = adjust_marlin_shard(
830
831
                    param, shard_size, shard_offset
                )
832

833
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
834
835
836
837
838
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

839
            if use_bitsandbytes_4bit:
840
841
842
843
844
845
846
847
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
848
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
849
            start_idx = self.tp_rank * shard_size
850
            if not is_sharded_weight:
851
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
852
853
854
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
855
856
                param_data, loaded_weight, loaded_shard_id
            )
857

858
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
859
860
861
862
863
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
864
865
                    "the same for all partitions."
                )
866

867
868
869
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

870
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: list[int] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter (v2 loader API).
            loaded_weight: Checkpoint tensor holding the concatenation of the
                shards listed in ``output_sizes`` along ``param.output_dim``.
            output_sizes: Sizes of the fused shards, in on-disk order.
                Defaults to ``self.output_sizes``.
            shard_ids: For each entry of ``output_sizes``, the index into
                ``self.output_sizes`` of the sub-module it belongs to.
                Defaults to ``0, 1, ..., len(output_sizes) - 1``, i.e. the
                fused tensor is assumed to start at the first sub-module.

        Raises:
            ValueError: if ``shard_ids`` and ``output_sizes`` differ in length
                (raised by ``zip(..., strict=True)``).
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = list(range(len(output_sizes)))

        # Offsets are computed on the unpacked (logical) sizes first; packing
        # adjustments are applied per shard below.
        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for shard_id, output_size in zip(shard_ids, output_sizes, strict=True):
            shard_offsets.append((shard_id, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            # shard_id is the absolute sub-module index, so the per-shard
            # loader computes the correct destination offset.
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a (possibly fused) checkpoint tensor into ``param``.

        Args:
            param: Destination parameter (v2 loader API).
            loaded_weight: Checkpoint tensor to load.
            loaded_shard_id: ``None`` when the whole fused weight is on disk,
                an ``int`` index into ``self.output_sizes`` for a single
                sub-module, or a tuple of such indices when the checkpoint
                stores several sub-modules fused together.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                if isinstance(loaded_shard_id, tuple):
                    for idx in loaded_shard_id:
                        param.load_merged_column_weight(
                            loaded_weight=loaded_weight, shard_id=idx
                        )
                else:
                    # When weights are already fused on disk (e.g. Phi-3's
                    # gate_up_proj), there is only a single scale for the
                    # entire fused matrix. Fill all slots with this scale
                    # to ensure that any subsequent reduction (like .max())
                    # works correctly while preserving the parameter shape.
                    for idx in range(param.data.shape[0]):
                        param.load_merged_column_weight(
                            loaded_weight=loaded_weight, shard_id=idx
                        )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return

            # Keep the absolute sub-module indices of a tuple shard id so the
            # fused loader writes each slice into the right sub-module (e.g.
            # (1, 2) must not be loaded into slots 0 and 1).
            shard_ids = list(loaded_shard_id) if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param,
                loaded_weight,
                output_sizes=output_sizes,
                shard_ids=shard_ids,
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset/size of this sub-module within the local (TP-sharded) param.
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
975

976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds a single KV head,
            # and each KV head is replicated on tp_size / total_num_kv_heads
            # ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Check that ``loaded_shard_id`` is ``None`` or one of 'q', 'k', 'v'.

        Raises:
            ValueError: for an unknown string id or a non-string, non-None id.
        """
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        raise ValueError(
            "Shard id for QKVParallelLinear must be a str ('q', 'k', 'v') or "
            f"None, got {type(loaded_shard_id).__name__}: {loaded_shard_id!r}."
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the offset of a shard within the local (per-rank) output.

        ``"total"`` maps to the full local q+k+v width; unknown ids give None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the size of a shard within the local (per-rank) output.

        Unknown ids give None.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return (shard_id, offset, size) for q/k/v in the full fused weight.

        Uses the total (un-sharded) head counts, i.e. the layout of a fused
        qkv tensor as stored on disk.
        """
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _get_bnb_qkv_offsets(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Build the q/k/v/total offset table passed to the bitsandbytes helper.

        Each entry is ``(offset, size)``; ``"total"`` is ``(q+k+v width, 0)``.
        Pass total head counts for fused on-disk weights and per-rank head
        counts for per-shard loading.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or a fused qkv weight) via the v2 param API.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor to load.
            loaded_shard_id: 'q', 'k', 'v', or None for a weight that is
                already fused on disk.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                # When weights are already fused on disk (e.g. Phi-3's
                # qkv_proj), there is only a single scale for the entire
                # fused matrix. Fill all slots (q, k, v) with this scale
                # to ensure that any subsequent reduction (like .max())
                # works correctly while preserving the parameter shape.
                for idx in range(param.data.shape[0]):
                    param.load_qkv_weight(
                        loaded_weight=loaded_weight, shard_id=idx, tp_rank=self.tp_rank
                    )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or a fused qkv weight) into ``param`` (v1 API).

        Args:
            param: Destination parameter; optional attributes such as
                ``output_dim``, ``packed_dim``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight`` and GGUF flags select the loading path.
            loaded_weight: Checkpoint tensor to load.
            loaded_shard_id: 'q', 'k', 'v', or None for a weight that is
                already fused on disk.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.

                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

                if packed_dim == output_dim:
                    shard_size = round(shard_size // param.packed_factor)
                    shard_offset = round(shard_offset // param.packed_factor)

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # Offsets of the full (un-sharded) fused layout.
                    orig_qkv_offsets = self._get_bnb_qkv_offsets(
                        self.total_num_heads, self.total_num_kv_heads
                    )
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offset/size of this shard within the local (per-rank) param.
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.v_head_size

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = round(shard_size // param.packed_factor)
                shard_offset = round(shard_offset // param.packed_factor)

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # Offsets of the local (per-rank) layout.
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.num_heads, self.num_kv_heads
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # K/V heads may be replicated across ranks, so consecutive ranks
            # read the same K/V slice from the checkpoint.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1387
# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Each rank computes the partial product X_i A_i; with reduce_results the
    partial products are summed with an all-reduce to produce Y.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if reduce_results is False while a bias would be added
                    in forward (bias=True and skip_bias_add=False).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Validate the argument combination up front so an invalid layer is
        # rejected before any weights are allocated by the quant method.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded: every rank produces a full-width
        # (partial-sum) output.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        # Positional arguments must line up with LinearBase.__init__.
        super().__init__(
            input_size,
            output_size,
            bias,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # BasevLLMParameter-based loader; all others use the legacy loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias is full-size on every rank (not sharded).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of a checkpoint tensor into ``param``.

        This is the legacy loader. The sharding rule is:

        * If ``param`` has an ``input_dim`` attribute and the checkpoint tensor
          is not already sharded, the slice starting at
          ``tp_rank * shard_size`` along ``input_dim`` is taken. Here
          ``shard_size`` is ``param``'s size in that dimension.
        * Otherwise, the tensor is copied whole.

        Args:
            param: Destination parameter. A GGUF ``UninitializedParameter`` is
                materialized here with the per-rank shape.
            loaded_weight: Tensor read from the checkpoint. It is treated as
                already sharded when ``param.is_sharded_weight`` is set or for
                bitsandbytes 4-bit weights.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shard shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: input_dim == 0 is a valid
            # dimension and must still be divided by tp_size. Otherwise the
            # narrow() below would derive shard_size from the full tensor.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's row-parallel loader.

        Sharding is delegated to ``param.load_row_parallel_weight``. The only
        preprocessing done here is promoting 0-d scalars (e.g. scales stored
        without a shape) to 1-element tensors.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel matmul to ``input_``.

        If ``input_is_parallel`` is False, ``input_`` is split along its last
        dimension and this rank's chunk is used.

        Returns:
            ``output`` if ``return_bias`` is False. Otherwise the tuple
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is set, else None.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        # Sum the per-rank partial products when requested.
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize per-partition input size, output size, bias and TP setup."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s