linear.py 60 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
25
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
    QuantizationConfig,
    QuantizeMethodBase,
)
29
30
31
32
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
    is_layer_moe_router_gate,
)
33
34
35
36
37
38
39
40
41
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
42
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.platforms import current_platform
44
45
46

logger = init_logger(__name__)

47
# Class names of quantization methods whose parameters are loaded through
# ``weight_loader_v2`` (the ``BasevLLMParameter`` loading path). Layers test
# ``quant_method.__class__.__name__`` for membership in this list; any method
# not listed here falls back to the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
67

68

69
70
71
72
73
74
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    If ``param`` has a non-``None`` ``marlin_tile_size`` attribute, both
    ``shard_size`` and ``shard_offset`` are multiplied by it; otherwise they
    are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)``, scaled when the attribute is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


81
82
83
84
85
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
86
87
88
89
90
91
92
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


93
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` maps each shard id to ``(offset, size)`` and has a
    ``"total"`` entry whose first element is the total those offsets are
    measured against. The entry for ``loaded_shard_id`` is rescaled by the
    ratio ``param.data.shape[0] / total`` using integer floor division.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    unscaled_total, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]

    scaled_total = param.data.shape[0]
    return (
        size * scaled_total // unscaled_total,
        offset * scaled_total // unscaled_total,
    )


110
111
112
113
114
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
115
116
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
117
118
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

134
    return param_data[shard_id], loaded_weight
135
136


137
138
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    Given bnb weight attributes such as:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    the function returns the first shard:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and the remaining shards, with offsets re-based to start at 0 and
    quant-state keys shifted down by one:
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    ``bnb_shard_offsets`` must support slicing and scalar subtraction
    (e.g. a numpy array or torch tensor).
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    # Re-base the remaining boundaries so the new first shard starts at 0.
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    tail_states = {
        new_idx: states[old_idx]
        for new_idx, old_idx in enumerate(range(1, len(offsets) - 1))
    }

    head = {"bnb_shard_offsets": head_offsets, "bnb_quant_state": head_states}
    tail = {"bnb_shard_offsets": tail_offsets, "bnb_quant_state": tail_states}
    return head, tail


175
class LinearMethodBase(QuantizeMethodBase):
    """Abstract base for the (possibly quantized) methods a linear layer uses.

    Subclasses allocate the layer's parameters in ``create_weights`` and
    compute the layer's output in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights of a linear layer and attach them to ``layer``.

        Args:
            layer: The layer that is using this LinearMethodBase factory.
            input_size_per_partition: Size of the weight's input dim on the
                current rank.
            output_partition_sizes: Output-dim size of each logical weight on
                the current rank. For a QKV layer, for example, this lists the
                widths of Wq, Wk and Wv on that rank.
            input_size: Size of the weight's input dim across all ranks.
            output_size: Size of the weight's output dim across all ranks.
            params_dtype: Data type of the parameters.
            **extra_weight_attrs: Extra attributes for the created weights,
                such as ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights held by ``layer`` to the input ``x``.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate one plain (unquantized) weight of shape
        # (sum(output_partition_sizes), input_size_per_partition): the output
        # partitions of a fused layer are stacked along dim 0, and the input
        # dim is this rank's partition.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        total_output_rows = sum(output_partition_sizes)
        weight = ModelWeightParameter(
            data=torch.empty(
                total_output_rows,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        # Every remaining extra attribute (weight_loader was popped above) is
        # attached to the new parameter.
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Only the CPU platform post-processes unquantized weights here.
        if not current_platform.is_cpu():
            return

        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # MoE router gates take the batch-invariant GEMM when batch
        # invariance is enabled on CUDA-like platforms.
        use_batch_invariant_gemm = (
            vllm_is_batch_invariant()
            and current_platform.is_cuda_alike()
            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
        )
        if use_batch_invariant_gemm:
            return linear_batch_invariant(x, layer.weight, bias)

        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
267
268


269
class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization config. ``None`` selects
            ``UnquantizedLinearMethod``; otherwise the config chooses the
            quant method for this layer.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is 0 and ``tp_size`` is 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters. Everything up to allow_fp8_block_shape_mismatch
        # is set before get_quant_method() is handed this layer.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        self.quant_method: QuantizeMethodBase | None = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp

        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto each BasevLLMParameter it owns."""
        vllm_params = (
            p for p in self.parameters() if isinstance(p, BasevLLMParameter)
        )
        for param in vllm_params:
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size
321
322


323
# --8<-- [start:replicated_linear]
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Weights are not partitioned: the full input_size is used as the
        # per-partition input size.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` in full (no sharding).

        GGUF parameters first get their scalar ``weight_type`` recorded and,
        if still uninitialized, are materialized with the on-disk shape/dtype.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        # When skip_bias_add is set the bias was not applied above, so it is
        # handed back to the caller instead.
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

442

443
# --8<-- [start:column_parallel_linear]
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_rank/tp_size
        # are needed before super().__init__() to size the partitions.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED load through
            # weight_loader_v2; all others use the legacy weight_loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (never None) rather than the raw argument.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag fused layers whose output partitions don't align to weight blocks.

        Reads ``quant_config.weight_block_size`` if present. When this layer
        has more than one output partition and any partition size is not a
        multiple of ``block_n = weight_block_size[0]``, sets
        ``self.allow_fp8_block_shape_mismatch = True``. Missing, non-integer,
        or non-positive block sizes leave the flag untouched.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        Along ``param.output_dim`` (when set) the full on-disk tensor is
        narrowed to the ``tp_rank``-th chunk of ``param``'s size, unless the
        weight is already sharded (``is_sharded_weight``) or is a
        bitsandbytes 4-bit weight.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's share of
        # the output dim.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via the parameter's own column-parallel loading logic."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        # When skip_bias_add is set the bias was not applied above, so it is
        # handed back to the caller instead.
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.mlp.gate_up_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        # Resolved before the parent constructor runs so that the
        # divisibility check below can use them.
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Each merged sub-output is sharded on its own, so every one of them
        # must split evenly across the tensor-parallel ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"Every entry of output_sizes={output_sizes} must be divisible by "
            f"tp_size={self.tp_size}."
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        """Check that ``loaded_shard_id`` refers to existing sub-outputs.

        Accepted forms: ``None`` (the tensor is fully fused on disk), a single
        index into ``self.output_sizes``, or a tuple of consecutive indices.

        Raises:
            ValueError: if an index is out of range, tuple indices are not
                consecutive, or the shard id has an unsupported type.
        """
        if loaded_shard_id is None:
            return
        num_shards = len(self.output_sizes)
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < num_shards):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{num_shards - 1}. Got shard id {loaded_shard_id}."
                    )
            if any(b - a != 1 for a, b in itertools.pairwise(loaded_shard_id)):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        if isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= num_shards:
                raise ValueError(
                    f"Shard id should be between 0 and {num_shards - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        # Reachable for any other type (e.g. a str shard id).
        raise ValueError(
            f"Unsupported shard id type {type(loaded_shard_id).__name__}: "
            f"{loaded_shard_id!r}. Expected None, an int, or a tuple of ints."
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Copy a checkpoint tensor into ``param`` (legacy loader).

        Args:
            param: destination parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: index of the sub-output ``loaded_weight`` belongs
                to, or ``None`` when it is already fused on disk. In the fused
                case with an ``output_dim`` the tensor is split and this method
                recurses once per sub-output.

        Raises:
            NotImplementedError: for tuple shard ids (use ``weight_loader_v2``).
            ValueError: from ``validate_shard_id`` for invalid shard ids.
        """
        self.validate_shard_id(loaded_shard_id)
        # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
        if isinstance(loaded_shard_id, tuple):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported in weight_loader, "
                "please use weight_loader_v2 instead."
            )

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i in range(len(self.output_sizes))
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)

            # (shard_id, shard_offset, shard_size) of every sub-output inside
            # the fused checkpoint tensor.
            shard_offsets: list[tuple[int, int, int]] = []
            current_shard_offset = 0
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size

            # The bitsandbytes offset table only depends on the layer's
            # output sizes, so build it once instead of once per shard.
            orig_offsets: dict[str, tuple[int, int]] = {}
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            # Offset/size of this sub-output within this rank's slice of param.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_id_offset: int = 0,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: destination vLLM parameter.
            loaded_weight: fused checkpoint tensor covering a consecutive run
                of sub-outputs.
            output_sizes: sizes of the sub-outputs contained in
                ``loaded_weight``; defaults to all of ``self.output_sizes``.
            shard_id_offset: shard id of the first sub-output contained in
                ``loaded_weight``. Required when the run does not start at 0
                (e.g. tuple shard id ``(1, 2)``) so that every slice is loaded
                into its own slot rather than into slots ``0, 1, ...``.
        """
        output_sizes = output_sizes or self.output_sizes

        # Offsets are accumulated from the unadjusted sizes; packing is applied
        # per shard below.
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for i, output_size in enumerate(output_sizes):
            shard_offsets.append(
                (shard_id_offset + i, current_shard_offset, output_size)
            )
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a checkpoint tensor into a vLLM parameter.

        Args:
            param: destination vLLM parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: ``None`` when the tensor is fully fused on disk,
                a tuple of consecutive indices when it fuses only those
                sub-outputs, or a single sub-output index.

        Raises:
            ValueError: from ``validate_shard_id`` for invalid shard ids.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            # First sub-output covered by the on-disk tensor: 0 when fully
            # fused, otherwise the first (smallest) index of the tuple.
            first_shard_id = loaded_shard_id[0] if loaded_shard_id else 0
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(
                    loaded_weight=loaded_weight, shard_id=first_shard_id
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            output_sizes = (
                [self.output_sizes[idx] for idx in loaded_shard_id]
                if loaded_shard_id
                else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                # Scales are stored per block, so split them using block sizes.
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param,
                loaded_weight,
                output_sizes=output_sizes,
                shard_id_offset=first_shard_id,
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset/size of this sub-output within this rank's slice of param.
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head (query and key).
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At most one KV head per rank: each rank keeps a single KV head,
            # shared by tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Check that ``loaded_shard_id`` is None or one of "q", "k", "v".

        Raises:
            ValueError: for any other string or any non-string value.
        """
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        # Reachable for any non-string shard id (e.g. an int).
        raise ValueError(
            f"Unsupported shard id type {type(loaded_shard_id).__name__}: "
            f"{loaded_shard_id!r}. Expected None or one of 'q', 'k', 'v'."
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None:
        """Return the per-rank output offset of ``loaded_shard_id``.

        ``"total"`` maps to the end of the ``v`` shard, i.e. the per-rank fused
        output size. Unknown ids return ``None``.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str) -> int | None:
        """Return the per-rank output size of ``loaded_shard_id`` (None if unknown)."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` of q, k, v in a fused checkpoint.

        Uses the unsharded (total) head counts, i.e. the on-disk layout.
        """
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into a vLLM parameter.

        Args:
            param: destination vLLM parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or ``None`` when the tensor is
                already fused on disk.

        Raises:
            ValueError: from ``validate_shard_id`` for invalid shard ids.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Copy a checkpoint tensor into ``param`` (legacy loader).

        Args:
            param: destination parameter of this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or ``None`` when the tensor is
                already fused on disk. In the fused case with an
                ``output_dim`` the tensor is split and this method recurses
                once per shard.

        Raises:
            ValueError: from ``validate_shard_id`` for invalid shard ids.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)

            # The bitsandbytes offset table is built from the unadjusted fused
            # layout and does not depend on the shard, so build it once.
            orig_qkv_offsets: dict[str, tuple[int, int]] = {}
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard_id: (shard_offset, shard_size)
                    for shard_id, shard_offset, shard_size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    sum(shard_size for _, _, shard_size in shard_offsets),
                    0,
                )

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offset/size of this shard within this rank's slice of param.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard: (
                        self._get_shard_offset_mapping(shard),
                        self._get_shard_size_mapping(shard),
                    )
                    for shard in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"),
                    0,
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are partitioned one slice per rank; KV heads may be
            # shared by num_kv_head_replicas consecutive ranks.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized: every
              rank holds the full bias, but only rank 0 adds it in forward.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: If ``reduce_results`` is False while the bias would be
            added inside ``forward`` (``bias=True`` and
            ``skip_bias_add=False``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first (input) dimension; the
        # output dimension stays whole on every rank. With disable_tp the
        # layer acts as a single rank (rank 0 of a group of size 1).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED delegate
            # sharding to the parameter itself (see weight_loader_v2);
            # every other method uses the legacy loader below.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            # The bias is not sharded: it spans the full output dimension on
            # every rank (forward only adds it on rank 0).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        Legacy (v1) loader. Behaviour is driven by attributes on ``param``:

        * ``input_dim``: dimension along which the checkpoint tensor is
          narrowed to this rank's shard, starting at ``tp_rank * shard_size``
          with ``shard_size = param.data.shape[input_dim]``. When absent, the
          whole tensor is copied.
        * ``is_sharded_weight`` / ``use_bitsandbytes_4bit``: the tensor is
          already this rank's shard, so no narrowing is done.
        * ``is_gguf_weight_type``: ``loaded_weight`` holds a single value,
          which is stored on ``param.weight_type``.
        * ``is_gguf_weight``: an ``UninitializedParameter`` is materialized
          with the checkpoint's shape, its ``input_dim`` divided by
          ``tp_size``.

        0-dim tensors (e.g. scales stored without a shape) are reshaped to
        ``(1,)`` before copying.

        Raises:
            AssertionError: If the (possibly narrowed) tensor's shape does
                not match ``param``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None (not truthiness): input_dim == 0 is a
            # valid dimension and must still be split across ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's row-parallel hook.

        Selected in ``__init__`` for quant methods listed in
        ``WEIGHT_LOADER_V2_SUPPORTED``; the sharding logic itself lives in
        ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply this rank's shard of the layer and optionally all-reduce.

        Args:
            input_: This rank's slice of the input when
                ``input_is_parallel`` is True; otherwise the full input,
                which is split into ``tp_size`` chunks along its last
                dimension and this rank's chunk is used.

        Returns:
            The output tensor when ``return_bias`` is False; otherwise a
            tuple ``(output, output_bias)`` where ``output_bias`` is
            ``self.bias`` if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        # All-reduce the per-rank partial outputs only when requested and
        # there is more than one rank; otherwise return the local result.
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Describe the per-rank input width and TP settings for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s