linear.py 59.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
11
12
13
14
15
16
17
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
18
from vllm.logger import init_logger
19
from vllm.model_executor.custom_op import PluggableLayer
20
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
from vllm.model_executor.custom_op import CustomOp
43
44
45

logger = init_logger(__name__)

46
# Class names of LinearMethod implementations whose parameters can be filled
# through ``weight_loader_v2`` (which hands loading to the BasevLLMParameter
# helpers such as ``load_column_parallel_weight``). Layers test
# ``type(quant_method).__name__`` for membership when choosing a loader.
# Extra names can be appended with
# ``register_weight_loader_v2_supported_method``, so this must stay a list.
WEIGHT_LOADER_V2_SUPPORTED = [
    # Unquantized and compressed-tensors methods.
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # AWQ / GPTQ / Marlin family.
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    # ModelOpt / Quark / Petit methods.
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
66

67

68
69
70
71
72
73
def register_weight_loader_v2_supported_method(cls):
    """Class decorator that opts a LinearMethod into ``weight_loader_v2``.

    Appends ``cls.__name__`` to ``WEIGHT_LOADER_V2_SUPPORTED`` and hands the
    class back unchanged, so it can be stacked with other decorators.
    """
    class_name = cls.__name__
    WEIGHT_LOADER_V2_SUPPORTED.append(class_name)
    return cls


74
75
76
77
78
79
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    If ``param`` has a ``marlin_tile_size`` attribute that is not None, both
    values are multiplied by it. Otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the adjustment.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


86
87
88
89
90
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
91
92
93
94
95
96
97
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


98
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Rescale a shard's offset and size onto a BitsAndBytes 4-bit parameter.

    ``shard_offsets`` maps each shard id to ``(offset, size)``. It also holds
    a ``"total"`` entry whose first element is the unquantized total. Both
    values of the requested shard are scaled by
    ``param.data.shape[0] / total`` using floor division.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    unquantized_total, _ = shard_offsets["total"]
    start, length = shard_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    def _rescale(value: int) -> int:
        return value * packed_rows // unquantized_total

    return _rescale(length), _rescale(start)


115
116
117
118
119
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
120
121
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
122
123
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

139
    return param_data[shard_id], loaded_weight
140
141


142
class LinearMethodBase(QuantizeMethodBase):
    """Abstract interface for linear methods, quantized or not.

    Subclasses allocate a layer's parameters in ``create_weights`` and
    perform the forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights of a linear layer.

        The weights are set as attributes of ``layer``.

        Args:
            layer: The layer that is using this LinearMethodBase factory.
            input_size_per_partition: Size of the weight's input dim on the
                current rank.
            output_partition_sizes: Output-dim size of each logical weight on
                the current rank. For a QKV layer, for example, this lists
                the widths of Wq, Wk and Wv on this rank.
            input_size: Size of the weight's input dim across all ranks.
            output_size: Size of the weight's output dim across all ranks.
            params_dtype: Data type of the parameters.
            **extra_weight_attrs: Extra attributes for the created parameters
                (layers pass e.g. ``weight_loader`` here).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the layer's output for input ``x`` (plus optional bias).

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError

182
183
@CustomOp.register("unquantized_linear_method")
class UnquantizedLinearMethod(LinearMethodBase, CustomOp):
    """Linear method that keeps the weights unquantized."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate a single unquantized (out, in) weight. Its output dim
        # stacks every logical partition, so the tensor holds
        # sum(output_partition_sizes) * input_size_per_partition elements.
        # ``weight_loader`` goes to the parameter constructor; every other
        # extra attribute is attached to the parameter afterwards.
        loader = extra_weight_attrs.pop("weight_loader")
        total_rows = sum(output_partition_sizes)
        storage = torch.empty(
            total_rows,
            input_size_per_partition,
            dtype=params_dtype,
        )
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Only the CPU platform post-processes unquantized weights here.
        if not current_platform.is_cpu():
            return
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Batch-invariant mode on CUDA-like platforms uses a dedicated
        # kernel; everything else goes through the dispatched GEMM.
        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
230
231


232
class LinearBase(PluggableLayer):
    """Common base for vLLM linear layers.

    Stores the layer configuration and selects the quant method. With no
    ``quant_config`` the method is ``UnquantizedLinearMethod``; otherwise the
    config picks it. Also records the tensor-parallel rank and size.

    Args:
        input_size: Input dimension of the linear layer.
        output_size: Output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Parameter dtype; ``torch.get_default_dtype()`` if None.
        quant_config: Quantization config, or None for no quantization.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism is disabled for this layer
            (rank 0, world size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        # ColumnParallelLinear may flip this to True after construction.
        self.allow_fp8_block_shape_mismatch = False
        self.quant_method: QuantizeMethodBase | None = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp

        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's TP rank/size onto each BasevLLMParameter it owns."""
        vllm_params = (
            p for p in self.parameters() if isinstance(p, BasevLLMParameter)
        )
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
284
285


286
# --8<-- [start:replicated_linear]
287
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full weight is created on every rank: ``create_weights`` receives the
    whole input size and all output partitions.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Subclasses that set ``output_sizes`` before calling this constructor
        # (e.g. a merged replicated layer) get one partition per logical output.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` without any sharding."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The trailing space in the first fragment keeps the two halves of
        # the message from running together.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        Returns only the output when ``return_bias`` is False. Otherwise
        returns ``(output, bias)``, where bias is the un-added bias when
        ``skip_bias_add`` is set and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

405

406
# --8<-- [start:column_parallel_linear]
407
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. TP info is
        # needed here, before LinearBase.__init__ sets it, to size partitions.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size) for logical_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # Must run after LinearBase.__init__, which resets the flag to False.
        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        uses_v2_loader = (
            self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
        )
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if uses_v2_loader else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (never None), matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag multi-partition layers whose partitions are not block-aligned.

        If the quant config has a ``weight_block_size`` and this layer has
        more than one output partition, ``allow_fp8_block_shape_mismatch`` is
        set to True when any partition size is not divisible by
        ``weight_block_size[0]``. Non-integer or non-positive block sizes are
        ignored.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        Unless the weight is already sharded, the slice is taken along the
        parameter's ``output_dim``, starting at ``tp_rank * shard_size``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        # Read .data only after a possible materialize above.
        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer, all-gathering the output if ``gather_output``."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
623
        quant_config: Quantization configure.
624
625
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
626
        return_bias: If true, return bias together with outputs in forward pass.
627
628
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
629
630
    """

631
632
633
634
635
636
637
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # ``output_sizes`` must exist before the parent constructor runs,
        # because ColumnParallelLinear.__init__ derives the per-rank
        # partition sizes from it when present.
        self.output_sizes = output_sizes

        # TP info is needed for the divisibility check below, which runs
        # before the parent constructor assigns these attributes.
        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        # Each logical output must split evenly across the TP ranks.
        assert all(size % self.tp_size == 0 for size in output_sizes)

        fused_output_size = sum(output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=fused_output_size,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
662

663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < len(self.output_sizes)):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        elif isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
                raise ValueError(
                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

690
691
692
693
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
694
        loaded_shard_id: tuple[int, ...] | int | None = None,
695
    ):
696
        self.validate_shard_id(loaded_shard_id)
697
698
699
700
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
701
702
703
704
705
706
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
707
        if is_gguf_weight_type:
708
709
710
711
712
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
713
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
714
                }
715
716
            return

717
718
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
719
720
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
721

722
            if loaded_shard_id is not None:
723
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
724
725
726
727
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
728

729
730
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
731
732
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
733

734
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
735
736
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
737
            if output_dim is None:
738
                if needs_scalar_to_array:
739
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
740
741
                        param_data, loaded_weight, 0
                    )
742

743
744
745
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
746
747
748
749
750
751

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
752
            current_shard_offset = 0
753
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
754
755
756
757
758
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
759
760
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
761
                    "for BNB quantization with TP yet."
762
                )
763
            shard_offsets: list[tuple[int, int, int]] = []
764
            for i, output_size in enumerate(output_sizes):
765
766
767
768
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
769
                # Special case for Quantization.
770
771
772
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
773
774
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
775
                    # Special case for Marlin.
776
                    shard_size, shard_offset = adjust_marlin_shard(
777
778
                        param, shard_size, shard_offset
                    )
779

780
                if use_bitsandbytes_4bit:
781
782
783
784
785
786
787
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
788
789
                        param, orig_offsets, str(shard_id)
                    )
790

791
                loaded_weight_shard = loaded_weight.narrow(
792
793
                    output_dim, shard_offset, shard_size
                )
794
795
796
797
798
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
799
800
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
801
802
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
803
804
805
806
807
808
809

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

810
            # Special case for quantization.
811
812
813
814
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
815
816
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
817
                # Special case for Marlin.
818
                shard_size, shard_offset = adjust_marlin_shard(
819
820
                    param, shard_size, shard_offset
                )
821

822
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
823
824
825
826
827
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

828
            if use_bitsandbytes_4bit:
829
830
831
832
833
834
835
836
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
837
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
838
            start_idx = self.tp_rank * shard_size
839
            if not is_sharded_weight:
840
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
841
842
843
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
844
845
                param_data, loaded_weight, loaded_shard_id
            )
846

847
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
848
849
850
851
852
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
853
854
                    "the same for all partitions."
                )
855

856
857
858
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

859
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: list[int] | tuple[int, ...] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. This function splits the fused checkpoint tensor
        into its shards and then calls the weight loader once per shard
        with that shard's id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter; pieces are cut along
                ``param.output_dim``.
            loaded_weight: Checkpoint tensor holding the shards concatenated
                along ``param.output_dim``.
            output_sizes: Size of each concatenated shard along the output
                dimension. ``None`` (or an empty list) means
                ``self.output_sizes``.
            shard_ids: Index into ``self.output_sizes`` of each entry of
                ``output_sizes``. Defaults to ``0, 1, 2, ...``, which is only
                correct when the fused tensor starts at the first shard.

        Raises:
            ValueError: If ``shard_ids`` and the output sizes differ in
                length.
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = range(len(output_sizes))
        elif len(shard_ids) != len(output_sizes):
            raise ValueError(
                f"Got {len(shard_ids)} shard ids for {len(output_sizes)} "
                "fused output sizes in MergedColumnParallelLinear."
            )

        current_shard_offset = 0
        for shard_id, output_size in zip(shard_ids, output_sizes):
            # Offsets are accumulated from the unadjusted sizes; packing is
            # applied per shard below.
            shard_offset, shard_size = current_shard_offset, output_size
            current_shard_offset += output_size

            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load ``loaded_weight`` into ``param`` using the v2 parameter API.

        ``loaded_shard_id`` describes what the checkpoint tensor contains:
        ``None`` when all shards are fused on disk, a tuple of indices into
        ``self.output_sizes`` for a fused subset, or a single index for one
        unfused shard.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                # NOTE(review): always targets slot 0, even for a tuple that
                # does not start at shard 0 -- confirm this is intended.
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # Keep the real shard indices so each piece lands in its own slot
            # even when the tuple does not start at shard 0.
            shard_ids = list(loaded_shard_id) if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                # Scale tensors are indexed in block units (see
                # adjust_block_scale_shard), so split by block-rounded sizes.
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param, loaded_weight, output_sizes=output_sizes, shard_ids=shard_ids
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset/size of this shard scaled down by tp_size, i.e. within this
        # rank's slice of the merged output dimension.
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
950

951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Along the output dimension the layout is ``[q | k | v]`` with sizes
    ``num_heads * head_size``, ``num_kv_heads * head_size`` and
    ``num_kv_heads * v_head_size`` (total head counts for the full
    checkpoint tensor, per-rank head counts for this rank's slice).

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head (used for q and k).
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer kv heads than ranks: each rank keeps one kv head, shared
            # by `num_kv_head_replicas` consecutive ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        # Must be set before super().__init__(), matching the original order.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=self.hidden_size,
            output_size=sum(self.output_sizes),
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Raise ``ValueError`` unless the id is None, "q", "k" or "v"."""
        if loaded_shard_id is None:
            return
        if not isinstance(loaded_shard_id, str):
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', 'v' or None, "
                f"got {type(loaded_shard_id).__name__} {loaded_shard_id!r}."
            )
        if loaded_shard_id not in ("q", "k", "v"):
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                f"got shard id {loaded_shard_id}."
            )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Per-rank offset of a shard (or "total"); None for unknown ids."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Per-rank size of a q/k/v shard; None for any other id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _qkv_layout(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Return ``{shard_id: (offset, size)}`` for a ``[q | k | v]`` layout.

        Also holds a ``"total"`` entry ``(total_size, 0)``; this is the dict
        shape this class passes to ``adjust_bitsandbytes_4bit_shard``.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k, v in a fused
        checkpoint tensor covering all heads (not split by rank)."""
        layout = self._qkv_layout(self.total_num_heads, self.total_num_kv_heads)
        return [(shard_id, *layout[shard_id]) for shard_id in ("q", "k", "v")]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = self._get_fused_shard_offsets()
        if isinstance(param, BlockQuantScaleParameter):
            # Scale tensors are indexed in block units (see
            # adjust_block_scale_shard); split by block-rounded sizes, as
            # MergedColumnParallelLinear does, not by raw weight sizes.
            weight_block_size = getattr(self, "weight_block_size", None)
            scale_offsets = []
            current_offset = 0
            for shard_id, _, shard_size in shard_offsets:
                scale_size = adjust_block_scale_shard(weight_block_size, shard_size, 0)[0]
                scale_offsets.append((shard_id, current_offset, scale_size))
                current_offset += scale_size
            shard_offsets = scale_offsets

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or a fused qkv tensor when the id is None)
        into ``param`` using the v2 parameter API."""
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ("q", "k", "v")

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard, or a fused qkv tensor when the id is None,
        into ``param`` (legacy, attribute-driven loader)."""
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            # Offsets/sizes within the fused checkpoint tensor (all heads).
            fused_layout = self._qkv_layout(
                self.total_num_heads, self.total_num_kv_heads
            )
            for shard_id in ("q", "k", "v"):
                shard_offset, shard_size = fused_layout[shard_id]
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, fused_layout, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offsets/sizes within this rank's slice of the qkv output.
            local_layout = self._qkv_layout(self.num_heads, self.num_kv_heads)
            shard_offset, shard_size = local_layout[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = (
                getattr(param, "is_sharded_weight", False) or use_bitsandbytes_4bit
            )

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, local_layout, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # k/v heads may be replicated across ranks, so several ranks
            # read the same k/v slice of the checkpoint.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1349
# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Each rank computes the partial product Y_i = X_i A_i; when
    ``reduce_results`` is set, the partials are summed with an all-reduce.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized: every
              rank holds the full vector, but it is added only on TP rank 0
              so it contributes exactly once to the reduced sum.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if ``reduce_results`` is false while a bias would be
            added inside the layer (``bias`` true and ``skip_bias_add`` false).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Validate the argument combination before any (potentially large)
        # weight tensors are created. Without the all-reduce each rank
        # returns only its partial product X_i A_i, so fusing the bias here
        # can give incorrect results; such callers must pass
        # skip_bias_add=True (or bias=False) and add the bias themselves.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded for row parallelism.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # Full-size (unsharded) bias. It carries no `input_dim`, so
            # `weight_loader` copies the checkpoint tensor without narrowing.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param``, keeping this rank's shard.

        If ``param`` has an ``input_dim`` attribute and the checkpoint tensor
        is not already sharded, ``loaded_weight`` is narrowed along
        ``input_dim`` to the ``tp_rank``-th chunk of size
        ``param.shape[input_dim]``.

        Special cases:
          * ``is_sharded_weight`` / bitsandbytes 4-bit params are loaded
            as-is (the loader already provides this rank's portion).
          * GGUF: a weight-type tensor sets ``param.weight_type``; an
            ``UninitializedParameter`` is materialized with its
            ``input_dim`` divided by ``tp_size``.
          * 0-dim tensors (e.g. scales) are reshaped to ``(1,)``.

        Args:
            param: Destination parameter (this rank's storage).
            loaded_weight: Tensor read from the checkpoint.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = (
            getattr(param, "is_sharded_weight", False) or use_bitsandbytes_4bit
        )

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than truthiness: input_dim == 0 is
            # a valid axis and must still be split across ranks, otherwise
            # the narrow below would index past the end on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's own row-parallel loader.

        Used for quant methods listed in ``WEIGHT_LOADER_V2_SUPPORTED``;
        sharding is delegated to ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute this rank's partial product and optionally all-reduce it.

        Args:
            input_: Activations. When ``input_is_parallel`` is false, the
                input is split along its last dimension into ``tp_size``
                chunks and this rank's chunk is used; otherwise it is
                assumed to already be this rank's shard.

        Returns:
            The output tensor if ``return_bias`` is false; otherwise a tuple
            ``(output, bias)`` where ``bias`` is ``self.bias`` when
            ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s