linear.py 60.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
25
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
    QuantizationConfig,
    QuantizeMethodBase,
)
29
30
31
32
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
    is_layer_moe_router_gate,
)
33
34
35
36
37
38
39
40
41
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
42
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.platforms import current_platform
44
45
46

logger = init_logger(__name__)

47
# Class names of the quantization methods whose parameters are loaded through
# the layer's ``weight_loader_v2`` entry point. Layers compare
# ``quant_method.__class__.__name__`` against this list when choosing which
# loader to hand to ``create_weights``.
WEIGHT_LOADER_V2_SUPPORTED = [
    # Unquantized / compressed-tensors
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # AWQ / GPTQ / Marlin / FP8 / TPU families
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    # ModelOpt / Quark / Petit
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
67

68

69
70
71
72
73
74
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size before adjustment.
        shard_offset: Shard offset before adjustment.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise the inputs unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


81
82
83
84
85
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
86
87
88
89
90
91
92
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


93
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The shard's original ``(offset, size)`` is rescaled by the ratio between
    the first dimension of ``param.data`` and the first element of
    ``shard_offsets["total"]``, using integer (floor) division.

    Args:
        param: Parameter whose ``data.shape[0]`` gives the quantized extent.
        shard_offsets: Maps shard ids, plus the key ``"total"``, to
            ``(offset, size)`` pairs.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    full_extent, _ = shard_offsets["total"]
    shard_start, shard_len = shard_offsets[loaded_shard_id]
    packed_extent = param.data.shape[0]

    def _rescale(value: int) -> int:
        return value * packed_extent // full_extent

    return _rescale(shard_len), _rescale(shard_start)


110
111
112
113
114
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
115
116
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
117
118
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

134
    return param_data[shard_id], loaded_weight
135
136


137
138
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
139
140
141
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    The remaining shards are re-based so that their offsets start at 0 and
    their quant-state keys start at 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Subtract the first boundary so the tail starts at offset 0.
    tail_offsets = offsets[1:] - offsets[1]

    head_states = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    tail_states = {
        new_idx: bnb_weight_attrs["bnb_quant_state"][new_idx + 1]
        for new_idx in range(len(offsets) - 2)
    }

    head = {"bnb_shard_offsets": head_offsets, "bnb_quant_state": head_states}
    tail = {"bnb_shard_offsets": tail_offsets, "bnb_quant_state": tail_states}
    return head, tail


175
class LinearMethodBase(QuantizeMethodBase):
    """Abstract interface for linear methods, quantized or not.

    Subclasses provide ``create_weights`` to attach parameters to a layer and
    ``apply`` to run the layer's computation with those parameters.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes supplied by the
                calling layer (the layers in this file pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: Layer holding the weights.
            x: Input tensor.
            bias: Optional bias tensor.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the unquantized ``weight`` parameter and register it.

        The parameter has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and dtype
        ``params_dtype``. ``extra_weight_attrs`` must contain
        ``weight_loader``; the remaining entries are attached to the weight
        through ``set_weight_attrs``.
        """
        # The loader is passed to the parameter constructor, so it is removed
        # from the extra attributes before those are applied.
        loader = extra_weight_attrs.pop("weight_loader")
        total_output_rows = sum(output_partition_sizes)
        storage = torch.empty(
            total_output_rows,
            input_size_per_partition,
            dtype=params_dtype,
        )
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the CPU GEMM dispatch setup on CPU; do nothing elsewhere."""
        if not current_platform.is_cpu():
            return

        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply ``layer.weight`` (and ``bias``) to ``x``.

        The batch-invariant kernel is used only when batch invariance is
        enabled, the platform is CUDA-alike, and the layer's prefix identifies
        a MoE router gate. Every other case goes through the dispatched
        unquantized GEMM.
        """
        use_batch_invariant = (
            vllm_is_batch_invariant()
            and current_platform.is_cuda_alike()
            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
        )
        if not use_batch_invariant:
            return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
        return linear_batch_invariant(x, layer.weight, bias)
267
268


269
class LinearBase(PluggableLayer):
    """Base linear layer.

    Stores the layer configuration, resolves the quantization method, and
    records the tensor-parallel rank and world size.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration. When ``None``,
            ``UnquantizedLinearMethod`` is used.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False

        # The quant method is resolved here, after the attributes above are
        # set and before the ones below, since the config receives ``self``.
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

        self.return_bias = return_bias
        self.disable_tp = disable_tp

        # With TP disabled the layer behaves as a single-rank group.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        for param in self.parameters():
            if not isinstance(param, BasevLLMParameter):
                continue
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size
321
322


323
# --8<-- [start:replicated_linear]
324
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # A subclass may have set ``output_sizes`` before calling this
        # constructor, which is why it is probed with hasattr.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The full input size is passed as the per-partition size: the weight
        # is replicated rather than split.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` without any sharding.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the sizes of ``param`` and ``loaded_weight``
                differ once a 0-dim ``loaded_weight`` has been reshaped to
                ``(1,)``.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The trailing space keeps the two message fragments from running
        # together.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns:
            The output alone when ``return_bias`` is false. Otherwise
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set (the bias was not added to the output)
            and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the summary of sizes and bias presence used in the repr."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

442

443
# --8<-- [start:column_parallel_linear]
444
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across tp ranks.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The TP layout is needed before the base constructor runs because the
        # partition sizes below depend on it.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)

        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size)
                for logical_size in self.output_sizes
            ]
        else:
            self.output_partition_sizes = [self.output_size_per_partition]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()
        self.gather_output = gather_output

        assert self.quant_method is not None
        # Quant methods on the supported list get the v2 loader.
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            chosen_loader = self.weight_loader_v2
        else:
            chosen_loader = self.weight_loader
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=chosen_loader,
        )

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=params_dtype)
            )
            bias_attrs = {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            }
            set_weight_attrs(self.bias, bias_attrs)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` when partitions are unaligned.

        The flag is set only if the quant config exposes a non-empty
        ``weight_block_size`` whose first entry converts to a positive int,
        the layer has more than one output partition, and at least one
        partition size is not a multiple of that entry.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if weight_block is None:
            return
        if len(weight_block) < 1:
            return
        if len(self.output_partition_sizes) <= 1:
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return
        if block_n <= 0:
            return

        has_unaligned_partition = any(
            size % block_n != 0 for size in self.output_partition_sizes
        )
        if not has_unaligned_partition:
            return

        self.allow_fp8_block_shape_mismatch = True
        logger.debug(
            "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
            getattr(self, "prefix", "<unknown>"),
            block_n,
            self.output_partition_sizes,
        )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim`` and is not flagged as already
        sharded (or as bitsandbytes 4-bit), the checkpoint tensor is narrowed
        along ``output_dim`` to this rank's slice. A 0-dim tensor is reshaped
        to ``(1,)``. The shapes are then asserted equal and the data copied.
        """
        output_dim = getattr(param, "output_dim", None)

        presharded = getattr(param, "is_sharded_weight", False)
        uses_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        skip_narrow = presharded or uses_bnb_4bit

        # Special case for GGUF
        gguf_weight = getattr(param, "is_gguf_weight", False)
        gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                full_extent = final_shape[output_dim]
                assert full_extent % self.tp_size == 0
                final_shape[output_dim] = full_extent // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not skip_narrow:
            shard_size = param_data.shape[output_dim]
            loaded_weight = loaded_weight.narrow(
                output_dim, self.tp_rank * shard_size, shard_size
            )

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if loaded_weight.dim() == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's column-parallel loader."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if loaded_weight.dim() == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``input_``.

        Returns:
            The output alone when ``return_bias`` is false, otherwise
            ``(output, output_bias)``. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set and ``None`` otherwise. The output is
            all-gathered only when ``gather_output`` is set and
            ``tp_size > 1``.
        """
        fused_bias = None if self.skip_bias_add else self.bias

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, fused_bias)

        output = output_parallel
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)

        if self.return_bias:
            returned_bias = self.bias if self.skip_bias_add else None
            return output, returned_bias
        return output

    def extra_repr(self) -> str:
        """Return the summary of sizes, bias and TP settings used in the repr."""
        fields = (
            f"in_features={self.input_size}",
            f"output_features={self.output_size_per_partition}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"gather_output={self.gather_output}",
        )
        return ", ".join(fields)

641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
660
        quant_config: Quantization configuration.
661
662
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
663
        return_bias: If true, return bias together with outputs in forward pass.
664
665
        disable_tp: If true, the weight matrices won't be sharded, and this
                    layer will be treated as a "Replicated" MergedLinear.
666
667
    """

668
669
670
671
672
673
674
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a fused column-parallel layer from ``output_sizes``.

        The base class is initialized with ``output_size=sum(output_sizes)``;
        every entry of ``output_sizes`` must be divisible by the
        tensor-parallel size in effect (1 when ``disable_tp`` is set).
        """
        self.output_sizes = output_sizes
        # tp_size / tp_rank are set before the base-class constructor runs so
        # the divisibility check below can use them.
        if disable_tp:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        assert not any(size % self.tp_size != 0 for size in output_sizes)
        total_output_size = sum(output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=total_output_size,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
699

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < len(self.output_sizes)):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        elif isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
                raise ValueError(
                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

727
728
729
730
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a checkpoint tensor into this layer's fused parameter.

        Args:
            param: Destination parameter. Optional attributes on it
                (``output_dim``, ``packed_dim``, ``is_gguf_weight``,
                ``use_bitsandbytes_4bit``, ...) select the special cases
                handled below.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Which sub-matrix ``loaded_weight`` holds.
                ``None`` means all sub-matrices, already fused on disk; a
                tuple means the listed consecutive sub-matrices, fused on
                disk; an int means that single sub-matrix.

        Raises:
            ValueError: If ``loaded_shard_id`` fails ``validate_shard_id``.
            NotImplementedError: If a tuple shard id is combined with GGUF or
                bitsandbytes 4-bit parameters.
        """
        self.validate_shard_id(loaded_shard_id)
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            if isinstance(loaded_shard_id, tuple):
                # The tuple is consecutive (see validate_shard_id), so the
                # k-th slice of loaded_weight belongs to shard first + k.
                first_shard_id = loaded_shard_id[0]
                output_sizes = self.output_sizes[
                    first_shard_id : loaded_shard_id[-1] + 1
                ]
            else:
                first_shard_id = 0
                output_sizes = self.output_sizes
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
                    "for BNB quantization yet."
                )
            # (absolute shard id, offset inside loaded_weight, size)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(output_sizes):
                # Use the absolute shard index so that the recursive call
                # below writes into the right slice of the fused parameter
                # even when the tuple does not start at shard 0.
                shard_offsets.append(
                    (first_shard_id + i, current_shard_offset, output_size)
                )
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        # From here on loaded_shard_id is a single int index.
        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
            shard_offset //= self.tp_size
            shard_size //= self.tp_size

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

887
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_id_offset: int = 0,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter.
            loaded_weight: Fused tensor, split along ``param.output_dim``.
            output_sizes: Sizes of the consecutive slices of
                ``loaded_weight``; defaults to ``self.output_sizes``.
            shard_id_offset: Shard index of the first slice. The k-th slice
                is loaded with shard id ``shard_id_offset + k``. The default
                of 0 matches a tensor that starts at the first sub-matrix.
        """

        current_shard_offset = 0
        # (absolute shard id, offset inside loaded_weight, size)
        shard_offsets: list[tuple[int, int, int]] = []
        output_sizes = output_sizes or self.output_sizes
        for i, output_size in enumerate(output_sizes):
            shard_offsets.append(
                (shard_id_offset + i, current_shard_offset, output_size)
            )
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` through its own loader.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``None`` when all sub-matrices are fused on
                disk, a tuple of consecutive indices when only those
                sub-matrices are fused on disk, or an int for one sub-matrix.

        Raises:
            ValueError: If ``loaded_shard_id`` fails ``validate_shard_id``.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            output_sizes = (
                [self.output_sizes[idx] for idx in loaded_shard_id]
                if loaded_shard_id
                else None
            )
            # A tuple is consecutive (see validate_shard_id), so its first
            # entry is the absolute index of the first slice on disk. Without
            # this offset a tuple such as (1, 2) would be loaded into shards
            # 0 and 1.
            shard_id_offset = loaded_shard_id[0] if loaded_shard_id else 0
            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param,
                loaded_weight,
                output_sizes=output_sizes,
                shard_id_offset=shard_id_offset,
            )
            return

        # From here on loaded_shard_id is a single int index.
        assert loaded_shard_id < len(self.output_sizes)

        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]
        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
978

979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
1001
        quant_config: Quantization configuration.
1002
1003
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
1004
        return_bias: If true, return bias together with outputs in forward pass.
1005
        disable_tp: If true, weights matrix won't be sharded through tp rank.
1006
1007
    """

1008
1009
1010
1011
1012
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        """Derive per-rank head counts and build the fused QKV projection.

        ``v_head_size`` falls back to ``head_size`` and ``total_num_kv_heads``
        falls back to ``total_num_heads`` when left as ``None``. The base
        class is always initialized with ``gather_output=False``.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = head_size if v_head_size is None else v_head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (
            total_num_heads if total_num_kv_heads is None else total_num_kv_heads
        )

        # Divide the weight matrix along the last dimension.
        tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # Enough KV heads to give every rank its own share.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # At least as many ranks as KV heads: one KV head per rank, each
            # head shared by several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)

        # Per-rank widths of the three projections.
        q_width = self.num_heads * self.head_size
        k_width = self.num_kv_heads * self.head_size
        v_width = self.num_kv_heads * self.v_head_size
        self.output_sizes = [
            q_width * tp_size,  # q_proj
            k_width * tp_size,  # k_proj
            v_width * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=self.hidden_size,
            output_size=(q_width + k_width + v_width) * tp_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
1064

1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
    def validate_shard_id(self, loaded_shard_id: str | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

1077
1078
1079
1080
1081
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
1082
1083
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
1084
1085
1086
1087
1088
1089
1090
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
1091
            "v": self.num_kv_heads * self.v_head_size,
1092
1093
1094
        }
        return shard_size_mapping.get(loaded_shard_id)

1095
1096
1097
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Split a QKV tensor that is already fused on disk and load each part.

        The checkpoint carries no shard id in this case, so the tensor is cut
        along ``param.output_dim`` into its q, k and v slices (sized from the
        total head counts) and each slice is passed to ``weight_loader_v2``
        with the matching shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        q_width = self.total_num_heads * self.head_size
        k_width = self.total_num_kv_heads * self.head_size
        v_width = self.total_num_kv_heads * self.v_head_size
        # (shard_id, offset inside loaded_weight, width)
        layout = (
            ("q", 0, q_width),
            ("k", q_width, k_width),
            ("v", q_width + k_width, v_width),
        )

        for shard_id, start, width in layout:
            # Packed (quantized) parameters store several values per element
            # along the output dim, so offset and width must be rescaled.
            is_packed_on_output_dim = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if is_packed_on_output_dim:
                width, start = param.adjust_shard_indexes_for_packing(
                    shard_size=width, shard_offset=start
                )

            shard = loaded_weight.narrow(param.output_dim, start, width)
            self.weight_loader_v2(param, shard, shard_id)

1139
1140
1141
1142
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v checkpoint tensor into ``param`` through its own loader.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: 'q', 'k' or 'v', or ``None`` when the three
                projections are already fused on disk.

        Raises:
            ValueError: If ``loaded_shard_id`` fails ``validate_shard_id``.
        """
        self.validate_shard_id(loaded_shard_id)

        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            block_size = getattr(self, "weight_block_size", None)
            size, offset = adjust_block_scale_shard(block_size, size, offset)

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=offset,
            shard_size=size,
            tp_rank=self.tp_rank,
        )
1178

1179
1180
1181
1182
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v checkpoint tensor into this layer's fused parameter.

        Args:
            param: Destination parameter. Optional attributes on it
                (``output_dim``, ``packed_dim``, ``is_gguf_weight``,
                ``use_bitsandbytes_4bit``, ...) select the special cases
                handled below.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: 'q', 'k' or 'v', or ``None`` when the three
                projections are already fused on disk.

        Raises:
            ValueError: If ``loaded_shard_id`` fails ``validate_shard_id``.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        gguf_weight = getattr(param, "is_gguf_weight", False)
        gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if gguf_weight_type:
            slot_of = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is None:
                param.shard_weight_type = {
                    name: loaded_weight.item() for name in slot_of
                }
            else:
                param.data[slot_of[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if gguf_weight:
            gguf_dim = getattr(param, "output_dim", None)
            gguf_len = loaded_weight.size(gguf_dim) // self.tp_size
            gguf_start = self.tp_rank * gguf_len

            if loaded_shard_id is not None:
                # Only this rank's slice is kept in the parameter's container.
                loaded_weight = loaded_weight.narrow(gguf_dim, gguf_start, gguf_len)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Widths of the unsharded q, k and v projections in the checkpoint.
            full_q = self.total_num_heads * self.head_size
            full_k = self.total_num_kv_heads * self.head_size
            full_v = self.total_num_kv_heads * self.v_head_size
            # (shard_id, offset inside loaded_weight, width)
            fused_layout = (
                ("q", 0, full_q),
                ("k", full_q, full_k),
                ("v", full_q + full_k, full_v),
            )
            use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for name, offset, width in fused_layout:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    width = width // param.packed_factor
                    offset = offset // param.packed_factor

                    # Special case for Marlin.
                    width, offset = adjust_marlin_shard(param, width, offset)

                if use_bnb_4bit:
                    unquantized_layout = {
                        "q": (0, full_q),
                        "k": (full_q, full_k),
                        "v": (full_q + full_k, full_v),
                        "total": (full_q + full_k + full_v, 0),
                    }

                    width, offset = adjust_bitsandbytes_4bit_shard(
                        param, unquantized_layout, name
                    )

                shard = loaded_weight.narrow(output_dim, offset, width)
                self.weight_loader(param, shard, name)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank widths of the q, k and v slices of the fused parameter.
            part_q = self.num_heads * self.head_size
            part_k = self.num_kv_heads * self.head_size
            part_v = self.num_kv_heads * self.v_head_size
            partition_layout = {
                "q": (0, part_q),
                "k": (part_q, part_k),
                "v": (part_q + part_k, part_v),
            }
            shard_offset, shard_size = partition_layout[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bnb_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            already_sharded = getattr(param, "is_sharded_weight", False) or use_bnb_4bit

            if use_bnb_4bit:
                unquantized_layout = {
                    "q": (0, part_q),
                    "k": (part_q, part_k),
                    "v": (part_q + part_k, part_v),
                    "total": (part_q + part_k + part_v, 0),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, unquantized_layout, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Ranks that share a replicated KV head read the same K/V slice.
            if loaded_shard_id == "q":
                source_rank = self.tp_rank
            else:
                source_rank = self.tp_rank // self.num_kv_head_replicas
            source_start = source_rank * shard_size

            if not already_sharded:
                loaded_weight = loaded_weight.narrow(
                    output_dim, source_start, shard_size
                )

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            if not getattr(param, "ignore_warning", False):
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1377
# --8<-- [start:row_parallel_linear]
1378
@PluggableLayer.register("row_parallel_linear")
1379
class RowParallelLinear(LinearBase):
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
1402
1403
1404
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
1405
        quant_config: Quantization configuration.
1406
1407
1408
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
1409
        disable_tp: If true, weights matrix won't be sharded through tp rank.
1410
1411
    """

1412
1413
    # --8<-- [end:row_parallel_linear]

1414
1415
1416
1417
1418
1419
1420
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a row-parallel linear layer and allocate its weights.

        Args:
            input_size: First dimension of the weight matrix; this is the
                dimension divided across tensor-parallel ranks.
            output_size: Second dimension of the weight matrix (not divided).
            bias: If true, allocate a full-size (unsharded) bias parameter.
            input_is_parallel: If true, ``forward`` uses its input as-is
                instead of splitting it along the last dimension.
            skip_bias_add: If true, the bias is not added in ``forward`` but
                handed back to the caller instead.
            params_dtype: Data type for the parameters.
            reduce_results: If true, ``forward`` all-reduces the output when
                ``tp_size > 1``.
            quant_config: Quantization configuration forwarded to the base
                class.
            prefix: Name of the layer in the state dict.
            return_bias: If true, ``forward`` returns ``(output, bias)``.
            disable_tp: If true, the layer behaves as rank 0 of a world of
                size 1, so the weight is not sharded.

        Raises:
            ValueError: If ``reduce_results`` is false while a bias would be
                added to the output (``bias and not skip_bias_add``).
        """
        # Divide the weight matrix along the first dimension.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        # Only the input dimension is partitioned; the output stays whole.
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
        # v2 loader; every other method keeps the legacy loader.
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name in WEIGHT_LOADER_V2_SUPPORTED:
            selected_loader = self.weight_loader_v2
        else:
            selected_loader = self.weight_loader

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=selected_loader,
        )

        bias_added_to_output = bias and not skip_bias_add
        if bias_added_to_output and not reduce_results:
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if not bias:
            self.register_parameter("bias", None)
        else:
            # The bias is allocated at full output size and always uses the
            # legacy loader.
            bias_attrs = {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            }
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, bias_attrs)
        self.update_param_tp_status()
1482
1483
1484

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param``, keeping this rank's shard.

        Behavior is driven by optional attributes read from ``param``:

        * ``input_dim``: dimension along which the weight is split across
          tensor-parallel ranks. When absent, the tensor is copied whole.
        * ``is_sharded_weight`` / ``use_bitsandbytes_4bit``: when either is
          truthy, ``loaded_weight`` is treated as already holding only this
          rank's portion and is not narrowed.
        * ``is_gguf_weight_type``: when truthy, the loaded scalar is stored
          on ``param.weight_type``.
        * ``is_gguf_weight``: when truthy and ``param`` is still an
          ``UninitializedParameter``, it is materialized with the per-rank
          shape before the copy.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the (possibly narrowed) tensor does not match
                the destination shape.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than relying on truthiness:
            # dimension 0 is a valid shard dimension, and the narrowing
            # below uses the same `is not None` test. A mismatch would
            # materialize the full shape and then narrow out of range.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # Slice out this rank's contiguous chunk along the input dim.
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"shape mismatch: param {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

1518
    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Hand a checkpoint tensor to ``param.load_row_parallel_weight``.

        A zero-dimensional tensor is first reshaped to shape ``(1,)``; any
        other tensor is passed through unchanged.

        Args:
            param: Destination parameter; it must provide
                ``load_row_parallel_weight``.
            loaded_weight: Tensor read from the checkpoint.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        rank = len(loaded_weight.shape)
        if rank == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1527
    def forward(
1528
1529
        self,
        input_,
1530
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
1531
1532
1533
1534
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
1535
1536
                input_, num_partitions=self.tp_size
            )
1537
            input_parallel = splitted_input[self.tp_rank].contiguous()
1538
1539

        # Matrix multiply.
1540
        assert self.quant_method is not None
1541
1542
1543
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
1544
1545
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

1546
        if self.reduce_results and self.tp_size > 1:
1547
            output = tensor_model_parallel_all_reduce(output_parallel)
1548
        else:
1549
1550
            output = output_parallel

1551
1552
        if not self.return_bias:
            return output
1553
        output_bias = self.bias if self.skip_bias_add else None
1554
        return output, output_bias
1555
1556

    def extra_repr(self) -> str:
        """Return a comma-separated summary of this layer's configuration.

        The summary lists the per-partition input size, the output size,
        whether a bias is present, the tensor-parallel size and the
        ``reduce_results`` flag.
        """
        fields = (
            f"in_features={self.input_size_per_partition}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"reduce_results={self.reduce_results}",
        )
        return ", ".join(fields)