linear.py 56.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
25
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
    QuantizationConfig,
    QuantizeMethodBase,
)
29
30
31
32
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
    is_layer_moe_router_gate,
)
33
34
35
36
37
38
39
40
41
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
42
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.platforms import current_platform
44
45
46

# Module-level logger; ColumnParallelLinear uses it to report, at debug level,
# when an FP8 block-shape mismatch is being allowed for a layer.
logger = init_logger(__name__)

# Class names of quantization methods whose parameters are loaded through the
# ``weight_loader_v2`` path (``BasevLLMParameter``-based loading) instead of
# ``weight_loader``. Layers select the loader by checking
# ``quant_method.__class__.__name__`` against this list. Kept as a list (not a
# set) so the public type stays unchanged for anything that imports it.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]


def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    If ``param`` carries a ``marlin_tile_size`` attribute, both ``shard_size``
    and ``shard_offset`` are multiplied by it; otherwise they are returned
    unchanged.

    Args:
        param: The parameter the shard is being loaded into.
        shard_size: Size of the shard along the packed dimension.
        shard_offset: Offset of the shard along the packed dimension.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Map a shard's size and offset onto block-scale granularity.

    Both values are divided by ``block_n = weight_block_size[0]`` and rounded
    up. This converts a shard expressed in weight rows into the matching range
    of rows of a block-quantization scale tensor.

    Args:
        weight_block_size: Block shape; only element 0 (``block_n``) is used.
            Must not be ``None``.
        shard_size: Shard size in weight rows.
        shard_offset: Shard offset in weight rows.

    Returns:
        A ``(shard_size, shard_offset)`` tuple in block units.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    # Ceiling division: a partially covered block still needs its own scale row.
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` maps each shard id to its unquantized ``(offset, size)``.
    The unquantized total length is stored as the first element of
    ``shard_offsets["total"]``. The selected shard is rescaled proportionally
    to ``param.data.shape[0]``, the length of the quantized parameter.

    Args:
        param: The (quantized) parameter being loaded into.
        shard_offsets: Per-shard ``(offset, size)`` plus a ``"total"`` entry.
        loaded_shard_id: Key of the shard to adjust.

    Returns:
        A ``(quantized_size, quantized_offset)`` tuple.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: Fused per-tensor scale array, indexed by shard position.
        loaded_weight: Scale for a single shard, either 0-d or of shape (1,).
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an int
            index.

    Returns:
        ``(param[shard_id], loaded_weight)`` where ``loaded_weight`` has been
        reduced to a 0-d value.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or not an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown string ids the same way as other invalid ids instead
        # of leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    Splits the attributes into the first shard ("left") and all remaining
    shards ("right"). The right side's offsets are rebased to start at 0, and
    its quant-state keys are renumbered from 0. ``bnb_shard_offsets`` must
    support slicing and element-wise subtraction of a scalar (e.g. a NumPy
    array).

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    quant_state = bnb_weight_attrs["bnb_quant_state"]

    offset_l = shard_offsets[:2]
    # Rebase the remaining boundaries so the right half starts at offset 0.
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: quant_state[0]}
    quant_state_r = {i - 1: quant_state[i] for i in range(1, len(shard_offsets) - 1)}

    left = {"bnb_shard_offsets": offset_l, "bnb_quant_state": quant_state_l}
    right = {"bnb_shard_offsets": offset_r, "bnb_quant_state": quant_state_r}
    return left, right


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights. The linear layers pass their ``weight_loader`` here.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: The layer whose weights are applied.
            x: Input tensor.
            bias: Optional bias to add to the output.

        Returns:
            The output tensor.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        Only this partition's slice is allocated: the tensor has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, with the
        output dim at 0 and the input dim at 1. ``weight_loader`` is popped
        from ``extra_weight_attrs`` and attached to the parameter. Any
        remaining attributes are set on it via ``set_weight_attrs``.
        """
        # The weights are not quantized. Memory allocated is
        # sum(output_partition_sizes) * input_size_per_partition elements.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the loaded layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            # Imported lazily; only needed on CPU.
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute ``x @ layer.weight.T (+ bias)`` with an unquantized GEMM."""
        # In batch-invariant mode on CUDA-like platforms, MoE router gates use
        # linear_batch_invariant instead of the regular dispatched GEMM.
        if (
            vllm_is_batch_invariant()
            and current_platform.is_cuda_alike()
            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
        ):
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration. ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is forced to 0 and ``tp_size`` to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        # Set before get_quant_method so the quant method always sees the flag.
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size


# --8<-- [start:replicated_linear]
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Replicated: the full input/output sizes are used on every rank.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` without any sharding.

        GGUF type markers are recorded on ``param.weight_type``, and GGUF
        ``UninitializedParameter``s are materialized to the loaded shape.
        0-d tensors (e.g. AutoFp8 scales) are reshaped to ``(1,)``.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Note the trailing space: the two f-strings are concatenated.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        Returns ``output`` when ``return_bias`` is false. Otherwise returns
        ``(output, bias)``, where ``bias`` is ``self.bias`` only when
        ``skip_bias_add`` is set (and was therefore not added), else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


# --8<-- [start:column_parallel_linear]
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias and instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_rank/tp_size
        # are needed here, before LinearBase.__init__, to size the partitions.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (None -> default dtype), matching the
            # weights created above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` when partitions are unaligned.

        The flag is set only when the quant config exposes a usable positive
        ``weight_block_size[0]`` (``block_n``), there are multiple output
        partitions, and at least one partition size is not a multiple of
        ``block_n``.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a full checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim`` and the weight is not already
        sharded (``is_sharded_weight`` or bitsandbytes 4-bit), the tensor is
        narrowed to the ``tp_rank``-th chunk along that dim. GGUF type markers
        and uninitialized GGUF params are handled first.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"to a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via the parameter's own ``load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer, optionally all-gathering the partitioned output.

        Returns ``output`` when ``return_bias`` is false. Otherwise returns
        ``(output, bias)``, where ``bias`` is ``self.bias`` only when
        ``skip_bias_add`` is set, else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
646
        quant_config: Quantization configure.
647
648
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
649
        return_bias: If true, return bias together with outputs in forward pass.
650
651
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
652
653
    """

654
655
656
657
658
659
660
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks
        # ``hasattr(self, "output_sizes")`` to size each logical partition.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every logical output must split evenly across the TP ranks.
        assert all(size % self.tp_size == 0 for size in output_sizes), (
            f"Each of output_sizes={output_sizes} must be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
690
        loaded_shard_id: int | None = None,
691
    ):
692
693
694
695
696
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
697
698
699
700
701
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
702
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
703
                }
704
705
            return

706
707
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
708
709
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
710

711
            if loaded_shard_id is not None:
712
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
713
714
715
716
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
717

718
719
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
720
721
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
722

723
        if loaded_shard_id is None:
724
725
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
726
            if output_dim is None:
727
                if needs_scalar_to_array:
728
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
729
730
                        param_data, loaded_weight, 0
                    )
731

732
733
734
735
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
736
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
737
            shard_offsets: list[tuple[int, int, int]] = []
738
739
740
741
742
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
743
                # Special case for Quantization.
744
745
746
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
747
748
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
749
                    # Special case for Marlin.
750
                    shard_size, shard_offset = adjust_marlin_shard(
751
752
                        param, shard_size, shard_offset
                    )
753

754
                if use_bitsandbytes_4bit:
755
756
757
758
759
760
761
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
762
763
                        param, orig_offsets, str(shard_id)
                    )
764

765
                loaded_weight_shard = loaded_weight.narrow(
766
767
                    output_dim, shard_offset, shard_size
                )
768
769
770
771
772
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
773
774
775
776
777
778
779
780
781
782
783
784
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            shard_offset //= self.tp_size
            shard_size //= self.tp_size

785
            # Special case for quantization.
786
787
788
789
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
790
791
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
792
                # Special case for Marlin.
793
                shard_size, shard_offset = adjust_marlin_shard(
794
795
                    param, shard_size, shard_offset
                )
796

797
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
798
799
800
801
802
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

803
            if use_bitsandbytes_4bit:
804
                shard_size = loaded_weight.shape[output_dim]
805
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
806

807
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
808
            start_idx = self.tp_rank * shard_size
809
            if not is_sharded_weight:
810
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
811
812
813
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
814
815
                param_data, loaded_weight, loaded_shard_id
            )
816

817
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
818
819
820
821
822
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
823
824
                    "the same for all partitions."
                )
825

826
827
828
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

829
830
831
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Split an on-disk fused MLP tensor into its merged shards and load each.

        Some checkpoints store all merged projections as a single tensor and
        therefore supply no shard id. This recovers every shard's
        (offset, size) along ``param.output_dim`` from ``self.output_sizes``,
        rescales them for packed parameters, and forwards each slice to
        ``weight_loader_v2`` together with its integer shard index.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        sizes = list(self.output_sizes)
        # Running start offsets: 0, s0, s0+s1, ... (zip drops the final total).
        starts = itertools.accumulate(sizes, initial=0)

        for shard_id, (start, size) in enumerate(zip(starts, sizes)):
            # Packed (quantized) params store several logical values per
            # element along the packed dim, so offsets/sizes must be rescaled.
            packed_on_output_dim = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_on_output_dim:
                size, start = param.adjust_shard_indexes_for_packing(
                    shard_size=size, shard_offset=start
                )

            shard = loaded_weight.narrow(param.output_dim, start, size)
            self.weight_loader_v2(param, shard, shard_id)

865
866
867
868
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one merged shard (or a whole fused tensor) into ``param``.

        Args:
            param: Destination vLLM parameter; it performs the actual copy via
                ``load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` identifying which
                merged projection ``loaded_weight`` belongs to, or ``None`` when
                the checkpoint tensor already holds every projection fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses fall through to the split below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Location of this shard within the full (unsharded) output dimension.
        start = sum(self.output_sizes[:loaded_shard_id])
        size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            size, start = adjust_block_scale_shard(
                getattr(self, "weight_block_size", None), size, start
            )

        # Each TP rank holds 1/tp_size of every merged shard.
        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=start // self.tp_size,
            shard_size=size // self.tp_size,
            tp_rank=self.tp_rank,
        )
903

904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention query/key head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size. Query and key heads always use
                     head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head and
            # num_kv_head_replicas consecutive ranks share the same one.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank start offset of shard "q"/"k"/"v"/"total".

        Offsets are along the output dimension of this rank's partition;
        unknown ids yield None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank size of shard "q"/"k"/"v" (None if unknown)."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return (shard_id, offset, size) for q, k, v in the full fused weight.

        Offsets and sizes use the total (unsharded) head counts, i.e. the
        layout of a QKV tensor that is already fused on disk.
        """
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or an already-fused QKV tensor) into ``param``.

        Args:
            param: Destination vLLM parameter; it performs the copy via
                ``load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or None when the checkpoint
                tensor already holds q, k and v fused together.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        # NOTE(review): the `num_heads` kwarg receives the KV replica count
        # (ranks sharing one KV head); presumably load_qkv_weight uses it like
        # weight_loader's `tp_rank // num_kv_head_replicas` — confirm in
        # parameter.py.
        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Legacy loader: copy a q/k/v shard (or fused QKV) into ``param``.

        Handles GGUF parameters, already-fused checkpoints (recursing once
        per q/k/v slice), packed/Marlin/bitsandbytes quantized layouts,
        block-quant scales, and per-tensor scales for fused layers.

        Args:
            param: Destination parameter, configured through optional
                attributes such as ``output_dim``, ``packed_dim``,
                ``use_bitsandbytes_4bit`` and ``needs_scalar_to_array``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or None for a fused tensor.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            if use_bitsandbytes_4bit:
                # Unpacked (offset, size) of q/k/v in the full fused weight plus
                # the total size; loop-invariant, so built once up front.
                orig_qkv_offsets = {
                    shard_id: (shard_offset, shard_size)
                    for shard_id, shard_offset, shard_size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    sum(shard_size for _, _, shard_size in shard_offsets),
                    0,
                )

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.v_head_size

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.v_head_size,
                    ),
                    "total": (
                        (self.num_heads + self.num_kv_heads) * self.head_size
                        + self.num_kv_heads * self.v_head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are split across all ranks; each KV head is shared by
            # num_kv_head_replicas consecutive ranks.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1288
# --8<-- [start:row_parallel_linear]
1289
@PluggableLayer.register("row_parallel_linear")
1290
class RowParallelLinear(LinearBase):
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
1313
1314
1315
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
1316
        quant_config: Quantization configure.
1317
1318
1319
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
1320
        disable_tp: If true, weights matrix won't be sharded through tp rank.
1321
1322
    """

1323
1324
    # --8<-- [end:row_parallel_linear]

1325
1326
1327
1328
1329
1330
1331
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # With TP disabled, act as the single rank of a group of size one.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        # Only the input (first weight) dimension is split across ranks;
        # every rank produces the full output width.
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED load through the
        # BasevLLMParameter-based loader; everything else uses the legacy one.
        supports_v2 = (
            self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
        )
        chosen_loader = self.weight_loader_v2 if supports_v2 else self.weight_loader
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=chosen_loader,
        )

        # Un-reduced partial outputs must not each carry the full bias.
        if not reduce_results and bias and not skip_bias_add:
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            bias_attrs = {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            }
            set_weight_attrs(self.bias, bias_attrs)
        self.update_param_tp_status()
1393
1394
1395

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's ``input_dim`` shard of ``loaded_weight`` into ``param``.

        Args:
            param: Destination parameter. Optional attributes read here:
                ``input_dim`` (dimension split across TP ranks),
                ``use_bitsandbytes_4bit`` / ``is_sharded_weight`` (checkpoint
                tensor is already this rank's portion, so it is not narrowed),
                and ``is_gguf_weight`` / ``is_gguf_weight_type`` (GGUF cases).
            loaded_weight: Tensor read from the checkpoint.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        # `is not None` matters: input_dim == 0 is a valid dimension, and a
        # truthiness test would skip the division, materializing a param that
        # no longer matches the narrowed shard copied below.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

1429
    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Hand ``loaded_weight`` to ``param.load_row_parallel_weight``.

        0-d checkpoint tensors (e.g. scales written by AutoFP8) are promoted
        to shape (1,) first.
        """
        if not loaded_weight.shape:
            # A 0-d tensor holds exactly one element.
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1438
    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply this rank's row shard and (optionally) all-reduce the result.

        Returns the output tensor, or ``(output, bias)`` when ``return_bias``
        is set; the bias is returned only when ``skip_bias_add`` is true.
        """
        if not self.input_is_parallel:
            # Keep only this rank's slice of the last dimension.
            pieces = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
            input_ = pieces[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only rank 0 fuses the bias into the GEMM, so the all-reduce below
        # does not add it more than once when TP > 1.
        fuse_bias = self.tp_rank <= 0 and not self.skip_bias_add
        result = self.quant_method.apply(
            self, input_, self.bias if fuse_bias else None
        )

        if self.reduce_results and self.tp_size > 1:
            result = tensor_model_parallel_all_reduce(result)

        if not self.return_bias:
            return result
        return result, (self.bias if self.skip_bias_add else None)
1466
1467

    def extra_repr(self) -> str:
        """Summarize per-rank input width, output width and TP settings."""
        parts = (
            f"in_features={self.input_size_per_partition}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"reduce_results={self.reduce_results}",
        )
        return ", ".join(parts)