linear.py 56.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

# Module-level logger for this file (used e.g. for the FP8 block-shape debug message).
logger = init_logger(__name__)

40
# Names of quant-method classes whose parameters are BasevLLMParameter
# subclasses and therefore load through the layer's ``weight_loader_v2``.
# Layers compare ``quant_method.__class__.__name__`` against this list; any
# method not listed falls back to the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
62

63

64
65
66
67
68
69
70
71
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    If ``param`` has no ``marlin_tile_size`` attribute (or it is None), the
    inputs are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


72
73
74
75
76
77
78
79
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a shard's size and offset into units of ``weight_block_size[0]``.

    Both values are divided by the block length along the first dimension,
    rounding up.

    Returns:
        A ``(shard_size, shard_offset)`` tuple measured in blocks.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]

    def _ceil_blocks(n):
        # Ceil-divide so a partially covered trailing block still counts.
        return (n + block_n - 1) // block_n

    return _ceil_blocks(shard_size), _ceil_blocks(shard_offset)


80
81
82
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The logical shard's offset and size are rescaled proportionally from the
    unquantized total length to the length of ``param``'s first dimension.

    Args:
        param: The (packed) parameter; only ``param.data.shape[0]`` is read.
        shard_offsets: Maps each shard id to ``(offset, size)`` in unquantized
            units. The special key ``"total"`` maps to ``(total_size, _)``.
        loaded_shard_id: Key into ``shard_offsets`` for the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` in units of ``param``'s first dim.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before dividing to stay in exact integer arithmetic.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


95
96
97
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for one shard.

    For fused modules (QKV and MLP) we have an array of length N that holds
    1 scale for each "logical" matrix, so the param is an array of length N.
    The loaded_weight corresponds to one of the shards on disk. Here, we
    slice the param based on the shard_id for loading.

    Args:
        param: The fused array, indexed along its first dimension.
        loaded_weight: The scale read from disk; either 0-d or with a
            leading dimension of length 1.
        shard_id: One of ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight)``, where ``loaded_weight`` has had its
        leading length-1 dimension removed if it had one.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or is neither a
            str nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names the same way as other invalid ids instead of
        # leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


118
119
120
121
122
123
124
125
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Split BitsAndBytes 4-bit shard metadata into the first shard and the rest.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    The right-hand offsets are re-based to start at 0 and its quant-state
    keys are renumbered from 0. ``bnb_shard_offsets`` must support slicing and
    element-wise subtraction (e.g. a numpy array or tensor).
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    quant_state = bnb_weight_attrs["bnb_quant_state"]

    offset_l = shard_offsets[:2]
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: quant_state[0]}
    quant_state_r = {i - 1: quant_state[i] for i in range(1, len(shard_offsets) - 1)}

    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


154
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra keyword attributes supplied by the
                calling layer (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The weight is allocated with shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` in
        ``params_dtype``. Those sizes are already per-partition, so how the
        full weight is split across ranks is decided by the calling layer.

        ``extra_weight_attrs`` must contain ``weight_loader``; every other
        entry is attached to the parameter via ``set_weight_attrs``.
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, call ``dispatch_cpu_unquantized_gemm`` with ``remove_weight=True``."""
        if current_platform.is_cpu():
            # Imported lazily; only the CPU path needs it.
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the GEMM chosen by ``dispatch_unquantized_gemm`` on ``x``,
        ``layer.weight`` and ``bias``."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
240
241


242
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. If None, uses
            ``torch.get_default_dtype()``.
        quant_config: Quantization config. If None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        # Subclasses may turn this on after inspecting their partition sizes.
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer acts as a single partition on rank 0.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto each BasevLLMParameter."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
294
295


296
# --8<-- [start:replicated_linear]
297
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on sharding; replicated layers always
            allocate the full weight.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Full (unpartitioned) sizes are passed for both the per-partition
        # and the global input dimension.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full checkpoint tensor into ``param`` (no TP slicing)."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        Returns the output alone when ``return_bias`` is False, otherwise
        ``(output, bias)`` where bias is only non-None if ``skip_bias_add``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

415

416
# --8<-- [start:column_parallel_linear]
417
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across TP ranks.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed before LinearBase.__init__ runs.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Use the resolved dtype (same as the weights), as ReplicatedLinear does.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` when needed.

        This applies only to layers with more than one output partition whose
        partition sizes are not all divisible by the quant config's
        ``weight_block_size[0]``. Missing or unusable block sizes are ignored.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        Unless the param is already sharded (``is_sharded_weight`` or
        ``use_bitsandbytes_4bit``), the tensor is narrowed along
        ``param.output_dim`` to the ``tp_rank``-th chunk whose size equals
        the param's size on that dim.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Shape mismatch: param {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        The output is all-gathered across TP ranks when ``gather_output`` is
        set and ``tp_size > 1``. Returns the output alone when
        ``return_bias`` is False, otherwise ``(output, bias)`` where bias is
        only non-None if ``skip_bias_add``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
633
        quant_config: Quantization configure.
634
635
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
636
        return_bias: If true, return bias together with outputs in forward pass.
637
638
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
639
640
    """

641
642
643
644
645
646
647
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks for
        # ``output_sizes`` to compute per-partition output sizes.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every logical output must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"output_sizes {output_sizes} must all be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
672

673
674
675
676
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
677
        loaded_shard_id: int | None = None,
678
    ):
679
680
681
682
683
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
684
685
686
687
688
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
689
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
690
                }
691
692
            return

693
694
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
695
696
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
697

698
            if loaded_shard_id is not None:
699
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
700
701
702
703
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
704

705
706
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
707
708
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
709

710
        if loaded_shard_id is None:
711
712
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
713
            if output_dim is None:
714
                if needs_scalar_to_array:
715
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
716
717
                        param_data, loaded_weight, 0
                    )
718

719
720
721
722
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
723
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
724
            shard_offsets: list[tuple[int, int, int]] = []
725
726
727
728
729
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
730
                # Special case for Quantization.
731
732
733
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
734
735
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
736
                    # Special case for Marlin.
737
                    shard_size, shard_offset = adjust_marlin_shard(
738
739
                        param, shard_size, shard_offset
                    )
740

741
                if use_bitsandbytes_4bit:
742
743
744
745
746
747
748
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
749
750
                        param, orig_offsets, str(shard_id)
                    )
751

752
                loaded_weight_shard = loaded_weight.narrow(
753
754
                    output_dim, shard_offset, shard_size
                )
755
756
757
758
759
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
760
761
762
763
764
765
766
767
768
769
770
771
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            shard_offset //= self.tp_size
            shard_size //= self.tp_size

772
            # Special case for quantization.
773
774
775
776
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
777
778
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
779
                # Special case for Marlin.
780
                shard_size, shard_offset = adjust_marlin_shard(
781
782
                    param, shard_size, shard_offset
                )
783

784
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
785
786
787
788
789
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

790
            if use_bitsandbytes_4bit:
791
                shard_size = loaded_weight.shape[output_dim]
792
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
793

794
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
795
            start_idx = self.tp_rank * shard_size
796
            if not is_sharded_weight:
797
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
798
799
800
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
801
802
                param_data, loaded_weight, loaded_shard_id
            )
803

804
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
805
806
807
808
809
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
810
811
                    "the same for all partitions."
                )
812

813
814
815
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

816
817
818
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """Split an on-disk fused tensor into its merged shards and load each.

        Some checkpoints store the merged output projections as one tensor,
        so no shard id accompanies it. The tensor is cut along
        ``param.output_dim`` into consecutive pieces whose lengths are given
        by ``self.output_sizes``. Each piece is handed to ``weight_loader_v2``
        with its position in ``self.output_sizes`` as the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        running_offset = 0
        for shard_id, size in enumerate(self.output_sizes):
            start, length = running_offset, size
            running_offset += size

            # Special case for Quantization.
            # Packed parameters need the slice rescaled along the packed
            # dimension when it coincides with the output dimension.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                length, start = param.adjust_shard_indexes_for_packing(
                    shard_size=length, shard_offset=start
                )

            piece = loaded_weight.narrow(param.output_dim, start, length)
            self.weight_loader_v2(param, piece, shard_id)

852
853
854
855
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one merged sub-projection (or a fused tensor) into ``param``.

        Args:
            param: Destination parameter. The actual copy is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` identifying
                which sub-projection ``loaded_weight`` holds, or None when the
                checkpoint tensor is not split per sub-projection.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match: subclasses of these fall through to the
                # fused-checkpoint path below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Position of this sub-projection in the full (all-ranks) output dim.
        global_offset = sum(self.output_sizes[:loaded_shard_id])
        global_size = self.output_sizes[loaded_shard_id]
        if isinstance(param, BlockQuantScaleParameter):
            # Block-quantized scales: let the helper rescale offset/size.
            global_size, global_offset = adjust_block_scale_shard(
                getattr(self, "weight_block_size", None), global_size, global_offset
            )

        # Each tensor-parallel rank owns 1/tp_size of every sub-projection.
        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=global_offset // self.tp_size,
            shard_size=global_size // self.tp_size,
            tp_rank=self.tp_rank,
        )
890

891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        # Value heads may have a different width than query/key heads.
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and every KV head is replicated on tp_size // total_num_kv_heads
            # ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's offset of shard ``"q"``/``"k"``/``"v"``.

        ``"total"`` maps to the per-rank fused size; any other id gives None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's size of shard ``"q"``/``"k"``/``"v"`` (else None)."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_checkpoint_shard_layout(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q/k/v in the full fused tensor.

        Offsets and sizes are along the output dimension of the unsharded
        (all-ranks) checkpoint tensor, before any packing or
        quantization-specific adjustment.
        """
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _get_partition_shard_layout(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q/k/v within this rank."""
        return [
            (
                shard_id,
                self._get_shard_offset_mapping(shard_id),
                self._get_shard_size_mapping(shard_id),
            )
            for shard_id in ("q", "k", "v")
        ]

    @staticmethod
    def _to_bitsandbytes_offsets(
        layout: list[tuple[str, int, int]],
    ) -> dict[str, tuple[int, int]]:
        """Build the offsets dict passed to ``adjust_bitsandbytes_4bit_shard``.

        Maps each shard id to ``(offset, size)`` and adds a ``"total"`` entry
        of ``(end_of_last_shard, 0)``.
        """
        offsets = {shard_id: (offset, size) for shard_id, offset, size in layout}
        _, last_offset, last_size = layout[-1]
        offsets["total"] = (last_offset + last_size, 0)
        return offsets

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_checkpoint_shard_layout():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or a fused qkv tensor) via ``param.load_qkv_weight``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or None when the
                checkpoint tensor is not split per projection.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        # load_qkv_weight's `num_heads` argument receives the KV replica count.
        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard, or a fused qkv tensor, into ``param``.

        Args:
            param: Destination parameter; optional attributes (``output_dim``,
                ``packed_dim``, ``use_bitsandbytes_4bit``, GGUF flags, ...)
                select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or None when the
                checkpoint stores q/k/v fused in one tensor. In that case the
                tensor is split and this method recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_checkpoint_shard_layout()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            orig_qkv_offsets = (
                self._to_bitsandbytes_offsets(shard_offsets)
                if use_bitsandbytes_4bit
                else None
            )

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank table weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._to_bitsandbytes_offsets(
                    self._get_partition_shard_layout()
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are partitioned one-to-one across ranks; KV heads
            # may be shared by num_kv_head_replicas consecutive ranks.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1275
# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce each rank returns only a partial sum, so a
        # bias applied inside forward would not combine correctly.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        Parameters carrying an ``input_dim`` attribute are sharded along that
        dimension: rank ``tp_rank`` takes the ``tp_rank``-th contiguous slice
        of ``param``'s size. Tensors flagged ``is_sharded_weight`` and
        bitsandbytes 4-bit tensors are copied without narrowing. GGUF
        parameters are materialized from the loaded tensor's shape first.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None: input_dim == 0 is a valid dimension and
            # must be divided by tp_size as well, otherwise the narrow below
            # would start past the end of the tensor on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``.

        Zero-dimensional (scalar) tensors are reshaped to shape ``(1,)``
        first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel projection.

        If ``input_is_parallel`` is False, the input is split along its last
        dimension and this rank keeps its own piece. The partial result is
        all-reduced when ``reduce_results`` is set and ``tp_size > 1``.

        Returns:
            The output alone when ``return_bias`` is False. Otherwise an
            ``(output, bias)`` tuple, where ``bias`` is ``self.bias`` if
            ``skip_bias_add`` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank shape and parallel settings for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s