linear.py 59.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
7

import torch
8
from torch.nn.parameter import Parameter, UninitializedParameter
9

10
11
12
13
14
15
16
17
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
18
from vllm.logger import init_logger
19
from vllm.model_executor.custom_op import PluggableLayer
20
21
22
23
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
24
from vllm.model_executor.layers.quantization.base_config import (
25
26
27
    QuantizationConfig,
    QuantizeMethodBase,
)
28
29
30
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
31
32
33
34
35
36
37
38
39
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
43
44

# Module-level logger for this file.
logger = init_logger(__name__)

45
# Class names of linear (quantization) methods that support the newer
# ``weight_loader_v2`` path. When a layer's ``quant_method.__class__.__name__``
# is in this list, column-parallel layers hand ``weight_loader_v2`` to
# ``create_weights`` instead of the legacy ``weight_loader``.
# Kept as a mutable list so that ``register_weight_loader_v2_supported_method``
# can append further names at import time.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
70
71
72
def register_weight_loader_v2_supported_method(cls):
    """Class decorator that opts a linear method into ``weight_loader_v2``.

    The decorated class's ``__name__`` is appended to the module-level
    ``WEIGHT_LOADER_V2_SUPPORTED`` list. The class object itself is returned
    untouched, so the decorator is transparent to the definition.
    """
    method_name = cls.__name__
    WEIGHT_LOADER_V2_SUPPORTED.append(method_name)
    return cls


73
74
75
76
77
78
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Parameters prepared for Marlin kernels may carry a ``marlin_tile_size``
    attribute. When it is present (and not ``None``), both values are
    multiplied by it. Otherwise they are returned unchanged.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Size of the shard along the packed dimension.
        shard_offset: Offset of the shard along the packed dimension.

    Returns:
        ``(shard_size, shard_offset)``, scaled when applicable.
    """
    marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


85
86
87
88
89
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
90
91
92
93
94
95
96
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


97
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The quantized parameter's leading dimension (``param.data.shape[0]``)
    differs from the unquantized total. The shard's original offset and size
    are therefore rescaled proportionally by ``quantized_total / total``,
    using floor division.

    Args:
        param: Quantized parameter. ``param.data.shape[0]`` is the quantized
            total.
        shard_offsets: Maps shard id -> ``(offset, size)`` in unquantized
            units. The special ``"total"`` key holds ``(total, _)``.
        loaded_shard_id: Key in ``shard_offsets`` of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


114
115
116
117
118
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
119
120
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
121
122
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

138
    return param_data[shard_id], loaded_weight
139
140


141
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Subclasses must implement ``create_weights`` (allocate and register the
    layer's parameters) and ``apply`` (run the forward computation).
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes forwarded by the layer
                (for example ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register a dense, unquantized ``weight`` parameter on ``layer``.

        ``extra_weight_attrs`` must contain ``weight_loader`` (popped here and
        passed to the parameter; a missing key raises ``KeyError``). All
        remaining entries are attached to the parameter via
        ``set_weight_attrs``.
        """
        # One tensor in params_dtype holds every logical output partition
        # stacked along dim 0 (output_dim=0), so it has
        # sum(output_partition_sizes) rows and input_size_per_partition
        # columns (input_dim=1). No packing or quantization is applied.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the loaded weight to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by ``layer.weight`` (plus optional ``bias``).

        Uses the batch-invariant kernel when batch invariance is enabled on a
        CUDA-alike platform. Otherwise uses the platform-dispatched GEMM.
        """
        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
229
230


231
class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization configuration. ``None`` selects
            ``UnquantizedLinearMethod``; otherwise the config picks the
            method for this layer.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (the layer behaves as rank 0 of a TP group of size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
283
284


285
# --8<-- [start:replicated_linear]
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The full input_size is passed as the per-partition input size:
        # nothing is split across ranks for a replicated layer.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a whole checkpoint tensor into ``param`` (no TP slicing).

        Raises:
            AssertionError: If the (reshaped) checkpoint tensor's size does
                not match the parameter's size.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Run the layer.

        Returns ``output`` when ``return_bias`` is false. Otherwise returns
        ``(output, output_bias)``, where ``output_bias`` is the bias only when
        ``skip_bias_add`` is set (it was not added to ``output``), else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

404

405
# --8<-- [start:column_parallel_linear]
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations.
                       When true, the bias is not added but returned instead.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The partition sizes below need the TP rank/size before
        # LinearBase.__init__ runs (which then sets them to the same values).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        # Divide the weight matrix along the output (last) dimension; the
        # input dimension stays whole on every rank.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(shard_size, self.tp_size) for shard_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (LinearBase maps None to the default
            # dtype), matching the weight and ReplicatedLinear's bias.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Tolerate output partitions that are not multiples of the block size.

        Sets ``allow_fp8_block_shape_mismatch`` to ``True`` when all of these
        hold: the quant config has a ``weight_block_size`` whose first entry
        (``block_n``) is a positive int; the layer has more than one output
        partition; and at least one partition size is not divisible by
        ``block_n``.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim`` and the checkpoint tensor is not
        already sharded (``is_sharded_weight`` or BitsAndBytes 4-bit), the
        rank's contiguous chunk of length ``param.shape[output_dim]`` is
        narrowed out of ``loaded_weight`` before copying.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's share of
        # the output dimension.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to the parameter's ``load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Run the layer, all-gathering the output if ``gather_output``.

        The all-gather only happens when ``tp_size > 1``. Returns ``output``
        when ``return_bias`` is false, else ``(output, output_bias)``, where
        ``output_bias`` is the bias only if ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
622
        quant_config: Quantization configure.
623
624
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
625
        return_bias: If true, return bias together with outputs in forward pass.
626
627
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
628
629
    """

630
631
632
633
634
635
636
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before the parent constructor: ColumnParallelLinear
        # checks ``hasattr(self, "output_sizes")`` to size each partition.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"Every output size in {output_sizes} must be divisible by the "
            f"tensor parallel size {self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
661

662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        """Check that ``loaded_shard_id`` addresses valid output shards.

        Args:
            loaded_shard_id: ``None`` (the checkpoint tensor covers every
                shard), a single shard index, or a non-empty tuple of
                consecutive shard indices.

        Raises:
            ValueError: If an index is out of range, a tuple is empty or not
                consecutive, or the id has an unsupported type.
        """
        if loaded_shard_id is None:
            return
        num_shards = len(self.output_sizes)
        if isinstance(loaded_shard_id, tuple):
            # An empty tuple selects no shard; the loader would later fail on
            # ``loaded_shard_id[0]``, so reject it here with a clear message.
            if not loaded_shard_id:
                raise ValueError(
                    "Shard id with multiple indices must not be empty. "
                    f"Got shard id {loaded_shard_id}."
                )
            for idx in loaded_shard_id:
                if not (0 <= idx < num_shards):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{num_shards - 1}. Got shard id {loaded_shard_id}."
                    )
            if any(b - a != 1 for a, b in zip(loaded_shard_id, loaded_shard_id[1:])):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        if isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= num_shards:
                raise ValueError(
                    f"Shard id should be between 0 and {num_shards - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError(
            "Shard id should be None, an int, or a tuple of ints. "
            f"Got {type(loaded_shard_id).__name__}: {loaded_shard_id!r}."
        )

689
690
691
692
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
693
        loaded_shard_id: tuple[int, ...] | int | None = None,
694
    ):
695
        self.validate_shard_id(loaded_shard_id)
696
697
698
699
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
700
701
702
703
704
705
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
706
        if is_gguf_weight_type:
707
708
709
710
711
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
712
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
713
                }
714
715
            return

716
717
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
718
719
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
720

721
            if loaded_shard_id is not None:
722
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
723
724
725
726
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
727

728
729
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
730
731
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
732

733
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
734
735
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
736
            if output_dim is None:
737
                if needs_scalar_to_array:
738
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
739
740
                        param_data, loaded_weight, 0
                    )
741

742
743
744
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
745
746
747
748
749
750

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
751
            current_shard_offset = 0
752
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
753
754
755
756
757
            if (
                use_bitsandbytes_4bit
                and isinstance(loaded_shard_id, tuple)
                and self.tp_size > 1
            ):
758
759
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
760
                    "for BNB quantization with TP yet."
761
                )
762
            shard_offsets: list[tuple[int, int, int]] = []
763
            for i, output_size in enumerate(output_sizes):
764
765
766
767
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
768
                # Special case for Quantization.
769
770
771
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
772
773
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
774
                    # Special case for Marlin.
775
                    shard_size, shard_offset = adjust_marlin_shard(
776
777
                        param, shard_size, shard_offset
                    )
778

779
                if use_bitsandbytes_4bit:
780
781
782
783
784
785
786
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
787
788
                        param, orig_offsets, str(shard_id)
                    )
789

790
                loaded_weight_shard = loaded_weight.narrow(
791
792
                    output_dim, shard_offset, shard_size
                )
793
794
795
796
797
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
798
799
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
800
801
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
802
803
804
805
806
807
808

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

809
            # Special case for quantization.
810
811
812
813
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
814
815
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
816
                # Special case for Marlin.
817
                shard_size, shard_offset = adjust_marlin_shard(
818
819
                    param, shard_size, shard_offset
                )
820

821
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
822
823
824
825
826
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

827
            if use_bitsandbytes_4bit:
828
829
830
831
832
833
834
835
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_offsets, str(loaded_shard_id)
                )
836
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
837
            start_idx = self.tp_rank * shard_size
838
            if not is_sharded_weight:
839
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
840
841
842
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
843
844
                param_data, loaded_weight, loaded_shard_id
            )
845

846
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
847
848
849
850
851
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
852
853
                    "the same for all partitions."
                )
854

855
856
857
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

858
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: list[int] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding the shards described by
                `output_sizes`, concatenated along `param.output_dim`.
            output_sizes: Size of each shard in `loaded_weight`, in order.
                Falls back to `self.output_sizes` when None or empty.
            shard_ids: Index into `self.output_sizes` of each entry of
                `output_sizes`. Defaults to `0, 1, ..., len(output_sizes) - 1`,
                which is only right when the fused tensor starts at shard 0.

        Raises:
            ValueError: if `shard_ids` and `output_sizes` differ in length.
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = list(range(len(output_sizes)))

        # Offsets are accumulated from the unpacked sizes; packing adjustments
        # below are applied per shard afterwards.
        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for shard_id, output_size in zip(shard_ids, output_sizes, strict=True):
            shard_offsets.append((shard_id, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            # Pass the absolute shard index so the int path computes the
            # destination offset within this layer's full output.
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a checkpoint tensor into a `BasevLLMParameter`.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor for the shard(s) named by
                `loaded_shard_id`.
            loaded_shard_id: Index into `self.output_sizes` of the shard held
                by `loaded_weight`. A tuple means `loaded_weight` holds those
                shards concatenated along the output dim; None means it holds
                every shard (fused on disk).
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # Absolute indices of the shards contained in `loaded_weight`;
            # None (or an empty tuple) means all shards.
            shard_ids = list(loaded_shard_id) if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param, loaded_weight, output_sizes=output_sizes, shard_ids=shard_ids
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Offset/size of this shard within the current TP rank's partition.
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]
        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and each KV head is shared by `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def validate_shard_id(self, loaded_shard_id: str | None):
        """Raise ValueError unless `loaded_shard_id` is None, "q", "k" or "v"."""
        if loaded_shard_id is None:
            return
        if not isinstance(loaded_shard_id, str):
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', 'v' or None, "
                f"got shard id of type {type(loaded_shard_id).__name__}."
            )
        if loaded_shard_id not in ["q", "k", "v"]:
            raise ValueError(
                "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                f"got shard id {loaded_shard_id}."
            )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Offset of a q/k/v shard within this rank's output partition.

        "total" maps to the full partition size; unknown ids return None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Size of a q/k/v shard within this rank's output partition.

        Unknown ids return None.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """(shard_id, offset, size) of q/k/v in an unsharded fused tensor."""
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _get_bnb_qkv_offsets(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """(offset, size) table passed to `adjust_bitsandbytes_4bit_shard`.

        Built for a q/k/v layout with the given head counts; the "total" entry
        holds (combined size, 0).
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into a `BasevLLMParameter`.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor for the named shard, or the fused
                q/k/v tensor when `loaded_shard_id` is None.
            loaded_shard_id: "q", "k", "v", or None.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into `param`.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor. For "q"/"k"/"v" it is narrowed
                to this rank's slice unless the param is marked
                `is_sharded_weight` or uses bitsandbytes 4-bit. For None it
                holds q, k and v concatenated along the param's `output_dim`.
            loaded_shard_id: "q", "k", "v", or None for a fused checkpoint.
        """
        self.validate_shard_id(loaded_shard_id)

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            if loaded_shard_id is not None:
                # k/v heads are replicated across ranks when tp_size exceeds
                # total_num_kv_heads; use the same rank mapping as the regular
                # path below so a rank never receives a fraction of a head.
                num_replicas = 1 if loaded_shard_id == "q" else self.num_kv_head_replicas
                shard_size = (
                    loaded_weight.size(output_dim) * num_replicas // self.tp_size
                )
                start_idx = (self.tp_rank // num_replicas) * shard_size
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # The fused tensor is unsharded, so bnb offsets use total heads.
            orig_qkv_offsets = (
                self._get_bnb_qkv_offsets(self.total_num_heads, self.total_num_kv_heads)
                if use_bitsandbytes_4bit
                else None
            )
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.num_heads, self.num_kv_heads
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Replicated k/v heads: consecutive ranks read the same slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1348
# --8<-- [start:row_parallel_linear]
@PluggableLayer.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. The bias is not parallelized; in forward it
              is added only on TP rank 0 so it is not summed more than once.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on the output and make Y
                        available to all GPUs; otherwise every GPU keeps its
                        own partial output Y_i = X_i A_i.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across TP ranks
                    (the layer behaves as tp_size == 1, tp_rank == 0).

    Raises:
        ValueError: if ``reduce_results`` is False while a bias would be added
            inside forward (``bias=True`` and ``skip_bias_add=False``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Validate the configuration before allocating any weights, so an
        # invalid layer fails fast instead of first creating the full weight.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias has the full output size on every rank (not sharded).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's portion of ``loaded_weight`` into ``param``.

        If ``param`` has an ``input_dim`` attribute and is not marked as
        already sharded (``is_sharded_weight`` or ``use_bitsandbytes_4bit``),
        ``loaded_weight`` is narrowed along ``input_dim`` to the slice owned
        by ``self.tp_rank``. Otherwise it is copied as-is. GGUF parameters
        are materialized first, and 0-dim scales are reshaped to ``(1,)``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `input_dim` may legitimately be 0; compare against None instead
            # of relying on truthiness, or dim 0 would never be sharded.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``.

        A 0-dim ``loaded_weight`` (a single scalar) is reshaped to ``(1,)``
        first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute this rank's partial product and optionally all-reduce it.

        If ``input_is_parallel`` is False, the input is split along its last
        dimension and only this rank's chunk is used. Returns the output
        alone when ``return_bias`` is False. Otherwise it returns
        ``(output, bias)``, where bias is ``self.bias`` only when
        ``skip_bias_add`` is set, and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank input size, output size, bias and TP setup."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s