linear.py 60.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
import vllm.envs as envs
11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

# Module-level logger for this file, created through vLLM's ``init_logger``.
logger = init_logger(__name__)

45
# Names of linear-method classes whose parameters are loaded through a layer's
# ``weight_loader_v2`` instead of the legacy ``weight_loader``. Layers compare
# ``quant_method.__class__.__name__`` against this list when creating weights;
# further classes can be added at import time with
# ``register_weight_loader_v2_supported_method``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
]
64

65

66
67
68
69
70
71
def register_weight_loader_v2_supported_method(cls):
    """Class decorator opting a linear method into ``weight_loader_v2``.

    Records ``cls.__name__`` in ``WEIGHT_LOADER_V2_SUPPORTED`` and hands the
    class back unchanged.
    """
    class_name = cls.__name__
    WEIGHT_LOADER_V2_SUPPORTED.append(class_name)
    return cls


72
73
74
75
76
77
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size as computed by the caller.
        shard_offset: Shard offset as computed by the caller.

    Returns:
        ``(shard_size, shard_offset)``, both multiplied by
        ``param.marlin_tile_size`` when that attribute is present and not
        None; otherwise returned unchanged.
    """
    marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


84
85
86
87
88
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
89
90
91
92
93
94
95
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


96
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The original ``(offset, size)`` of ``loaded_shard_id`` is rescaled
    proportionally from the original total size to the size of the quantized
    parameter's first dimension.

    Args:
        param: Quantized parameter; ``param.data.shape[0]`` is its total size.
        shard_offsets: Maps each shard id to its original ``(offset, size)``;
            the ``"total"`` entry holds the original total size first.
        loaded_shard_id: Key of the shard to convert.

    Returns:
        ``(quantized_size, quantized_offset)`` using floor division.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


113
114
115
116
117
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
118
119
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
120
121
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

137
    return param_data[shard_id], loaded_weight
138
139


140
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            extra_weight_attrs: Extra attributes passed by the layer, such as
                ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a dense, unquantized ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        ``input_size`` and ``output_size`` are not used here.
        ``extra_weight_attrs`` must contain ``weight_loader``. It is consumed
        here, and any remaining entries are set as attributes on the weight.
        """
        # The partition sizes are already the per-rank sizes, so this
        # allocation is this rank's slice of the weight, kept unquantized.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x`` times ``layer.weight`` (plus ``bias``).

        Uses the batch-invariant kernel when ``VLLM_BATCH_INVARIANT`` is set
        on a CUDA-alike platform. Otherwise uses the GEMM returned by
        ``dispatch_unquantized_gemm``.
        """
        if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
228
229


230
class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            # The annotation above allows None; subclasses assert it is set
            # before use.
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer behaves as rank 0 of a group of size 1.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's TP rank/size onto each ``BasevLLMParameter``."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
282
283


284
# --8<-- [start:replicated_linear]
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Nothing is split: the full input size is passed as the
        # per-partition input size.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` as-is (no narrowing).

        GGUF parameters first record their weight type and/or are
        materialized to the loaded shape.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns just the output when ``return_bias`` is False. Otherwise it
        returns ``(output, bias)``, where ``bias`` is the un-added bias when
        ``skip_bias_add`` is set and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

403

404
# --8<-- [start:column_parallel_linear]
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # TP rank/size are needed before LinearBase.__init__ because the
        # per-partition sizes below depend on them.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

        # Divide the weight matrix along the last (output) dimension; the
        # input dimension is not split.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size) for logical_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (LinearBase replaces None with the
            # default dtype), matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag layers whose fused partitions don't align to the FP8 block size.

        Sets ``allow_fp8_block_shape_mismatch`` to True when the quant config
        has a usable positive ``weight_block_size[0]``, the layer has more
        than one output partition, and at least one partition size is not a
        multiple of that block size.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim``, the slice
        ``[tp_rank * shard_size, (tp_rank + 1) * shard_size)`` along that
        dimension is taken. This is skipped for weights marked as already
        sharded and for bitsandbytes 4-bit weights.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's share of
        # the output dimension.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer, all-gathering across TP ranks if ``gather_output``.

        Returns just the output when ``return_bias`` is False. Otherwise it
        returns ``(output, bias)``, where ``bias`` is the un-added bias when
        ``skip_bias_add`` is set and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
621
        quant_config: Quantization configure.
622
623
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
624
        return_bias: If true, return bias together with outputs in forward pass.
625
626
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
627
628
    """

629
630
631
632
633
634
635
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Set before ColumnParallelLinear.__init__, which looks for
        # ``output_sizes`` to compute one partition size per logical weight.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"Every entry of output_sizes={output_sizes} must be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        """Check a shard id passed to ``weight_loader``.

        Accepted forms are ``None`` (the whole fused weight), a single integer
        index into ``self.output_sizes``, or a non-empty tuple of consecutive
        integer indices.

        Raises:
            ValueError: if an index is out of range, the tuple is empty or not
                consecutive, or the id has an unsupported type.
        """
        if loaded_shard_id is None:
            return
        num_shards = len(self.output_sizes)
        if isinstance(loaded_shard_id, tuple):
            # An empty tuple would otherwise pass here and fail later when
            # weight_loader indexes loaded_shard_id[0].
            if not loaded_shard_id:
                raise ValueError(
                    "Shard id tuple must contain at least one index. "
                    f"Got shard id {loaded_shard_id}."
                )
            for idx in loaded_shard_id:
                if not (0 <= idx < num_shards):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{num_shards - 1}. Got shard id {loaded_shard_id}."
                    )
            if any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        if isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= num_shards:
                raise ValueError(
                    f"Shard id should be between 0 and {num_shards - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError(
            "Shard id should be None, an int, or a tuple of ints. "
            f"Got {type(loaded_shard_id).__name__}: {loaded_shard_id!r}."
        )

688
689
690
691
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
692
        loaded_shard_id: tuple[int, ...] | int | None = None,
693
    ):
694
        self.validate_shard_id(loaded_shard_id)
695
696
697
698
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
699
700
701
702
703
704
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
705
        if is_gguf_weight_type:
706
707
708
709
710
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
711
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
712
                }
713
714
            return

715
716
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
717
718
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
719

720
            if loaded_shard_id is not None:
721
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
722
723
724
725
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
726

727
728
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
729
730
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
731

732
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
733
734
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
735
            if output_dim is None:
736
                if needs_scalar_to_array:
737
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
738
739
                        param_data, loaded_weight, 0
                    )
740

741
742
743
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
744
745
746
747
748
749

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
750
            current_shard_offset = 0
751
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
752
753
754
755
756
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
757
758
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
759
                    "for BNB quantization with TP yet."
760
                )
761
            shard_offsets: list[tuple[int, int, int]] = []
762
            for i, output_size in enumerate(output_sizes):
763
764
765
766
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
767
                # Special case for Quantization.
768
769
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
770
771
772
773
774
775
776
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

777
                if packed_dim == output_dim:
778
779
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
780
                    # Special case for Marlin.
781
                    shard_size, shard_offset = adjust_marlin_shard(
782
783
                        param, shard_size, shard_offset
                    )
784

785
                if use_bitsandbytes_4bit:
786
787
788
789
790
791
792
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
793
794
                        param, orig_offsets, str(shard_id)
                    )
795

796
                loaded_weight_shard = loaded_weight.narrow(
797
798
                    output_dim, shard_offset, shard_size
                )
799
800
801
802
803
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
804
805
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
806
807
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
808
809
810
811
812
813
814

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

815
            # Special case for quantization.
816
817
818
819
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
820
821
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
822
                # Special case for Marlin.
823
                shard_size, shard_offset = adjust_marlin_shard(
824
825
                    param, shard_size, shard_offset
                )
826

827
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
828
829
830
831
832
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

833
            if use_bitsandbytes_4bit:
834
835
836
837
838
839
840
841
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
842
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
843
            start_idx = self.tp_rank * shard_size
844
            if not is_sharded_weight:
845
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
846
847
848
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
849
850
                param_data, loaded_weight, loaded_shard_id
            )
851

852
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
853
854
855
856
857
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
858
859
                    "the same for all partitions."
                )
860

861
862
863
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

864
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: list[int] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding the shards concatenated
                along ``param.output_dim``.
            output_sizes: Sizes of the shards contained in ``loaded_weight``,
                in order. Defaults to ``self.output_sizes`` (all shards).
            shard_ids: Index into ``self.output_sizes`` of each entry of
                ``output_sizes``. Defaults to ``0 .. len(output_sizes) - 1``,
                which is only correct when the loaded shards start at 0.

        Raises:
            ValueError: if ``shard_ids`` and ``output_sizes`` differ in length.
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = list(range(len(output_sizes)))
        elif len(shard_ids) != len(output_sizes):
            raise ValueError(
                f"Got {len(shard_ids)} shard ids for {len(output_sizes)} "
                "output sizes in MergedColumnParallelLinear."
            )

        # (shard_id, offset within loaded_weight, size) for every shard.
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for shard_id, output_size in zip(shard_ids, output_sizes):
            shard_offsets.append((shard_id, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            # Re-enter with an int shard id so the per-shard path below places
            # the slice at the correct offset of the destination parameter.
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` for one or more merged shards.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor.
            loaded_shard_id: ``None`` when every shard is fused on disk; a
                tuple of shard indices when ``loaded_weight`` holds several
                shards concatenated in that order; or a single int index.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                if isinstance(loaded_shard_id, tuple):
                    for idx in loaded_shard_id:
                        param.load_merged_column_weight(
                            loaded_weight=loaded_weight, shard_id=idx
                        )
                else:
                    param.load_merged_column_weight(
                        loaded_weight=loaded_weight, shard_id=0
                    )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return

            # The real shard indices contained in loaded_weight. They must be
            # forwarded so that e.g. (1, 2) lands in slots 1 and 2, not 0 and 1.
            shard_ids = list(loaded_shard_id) if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param,
                loaded_weight,
                output_sizes=output_sizes,
                shard_ids=shard_ids,
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset/size of this shard within the local (per-TP-rank) parameter.
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
963

964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head (used for query and key).
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, head_size is used.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head and
            # num_kv_head_replicas ranks share it.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Check that ``loaded_shard_id`` is ``None`` or one of 'q', 'k', 'v'.

        Raises:
            ValueError: for an unknown string id or a non-string id.
        """
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        raise ValueError(
            "Shard id for QKVParallelLinear should be 'q', 'k', 'v' or None, "
            f"got {loaded_shard_id!r} of type {type(loaded_shard_id).__name__}."
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank offset of shard ``loaded_shard_id`` ('q', 'k',
        'v' or 'total') along the output dimension, or None if unknown."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank size of shard ``loaded_shard_id`` ('q', 'k' or
        'v') along the output dimension, or None if unknown."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v inside an
        unsharded fused qkv checkpoint tensor (sizes use the total head
        counts, not the per-rank ones)."""
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _get_bnb_qkv_offsets(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Build the ``{shard_id: (offset, size)}`` table (plus a ``"total"``
        entry holding the summed size) consumed by
        ``adjust_bitsandbytes_4bit_shard`` for the given head counts."""
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load ``loaded_weight`` into a vLLM parameter for shard 'q', 'k' or
        'v', or split and load a fused qkv tensor when the id is None."""
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        # NOTE(review): the KV replication factor is passed as `num_heads`;
        # presumably parameter.py uses it to map tp_rank to a KV shard — verify.
        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` for shard 'q', 'k' or 'v'.

        With ``loaded_shard_id=None`` the checkpoint tensor is treated as the
        already-fused qkv weight: it is split into q/k/v slices which are
        loaded recursively (or copied whole when the param has no
        ``output_dim``).
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
                # Add check to adjust the size/offset for FP8 block scales
                if isinstance(param, BlockQuantScaleParameter):
                    weight_block_size = getattr(self, "weight_block_size", None)
                    shard_size, shard_offset = adjust_block_scale_shard(
                        weight_block_size, shard_size, shard_offset
                    )

                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # The fused tensor on disk is unsharded: use total counts.
                    orig_qkv_offsets = self._get_bnb_qkv_offsets(
                        self.total_num_heads, self.total_num_kv_heads
                    )
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank location of this shard inside the local parameter.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.num_heads, self.num_kv_heads
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are always partitioned across ranks. KV heads may be
            # replicated (see __init__), in which case num_kv_head_replicas
            # consecutive ranks read the same KV slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1369
# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across tensor
                    parallel ranks (the layer behaves as tp_rank=0,
                    tp_size=1).

    Raises:
        ValueError: if ``reduce_results`` is False while the bias would be
            added inside the layer (``bias=True`` and
            ``skip_bias_add=False``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        # Each rank owns a 1/tp_size slice of the input features; the output
        # dimension is not partitioned (partial results are summed instead).
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED load through
            # the parameter-class based loader; all others use the legacy one.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce each rank holds only a partial sum, so a bias
        # added inside the layer could not be applied exactly once.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            # The bias spans the full (unpartitioned) output dimension.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``input_dim`` attribute and the checkpoint tensor
        is not already sharded, ``loaded_weight`` is narrowed along that
        dimension to the slice owned by ``self.tp_rank``. Otherwise it is
        copied as-is. Scalars are reshaped to shape ``(1,)``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            # The loaded tensor holds a single value that is stored as a
            # plain attribute on the parameter.
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` rather than truthiness: dim 0 is a valid input_dim
            # and must be sharded too, matching the narrowing check below.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` by delegating to the vLLM parameter class."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute ``input_ @ A (+ bias)`` with row-parallel weights.

        Returns the output tensor, or ``(output, bias_or_None)`` when
        ``return_bias`` is set. The returned bias is non-None only when
        ``skip_bias_add`` is True.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Take this rank's slice of the last (feature) dimension.
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        # Sum the per-rank partial products across the TP group.
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s