linear.py 57.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

logger = init_logger(__name__)

# Class names (compared against ``quant_method.__class__.__name__``) of the
# linear quantization methods for which ColumnParallelLinear wires up
# ``weight_loader_v2``; that loader delegates to the parameter's own
# ``load_column_parallel_weight``. Any other method gets the legacy
# ``weight_loader``. Order is irrelevant to lookups but kept stable.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size/offset into BitBLAS tile units.

    When ``param`` exposes a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are passed through untouched.

    Returns:
        ``(shard_size, shard_offset)`` in the (possibly) rescaled units.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


75
76
77
78
79
80
81
82
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size/offset by the parameter's Marlin tile size.

    When ``param`` has no ``marlin_tile_size`` attribute (or it is ``None``)
    the values are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after optional scaling.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


83
84
85
86
87
88
89
90
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Map a weight shard's size/offset onto block-scale indices.

    Both values are ceil-divided by ``weight_block_size[0]`` (the block size
    along the output dimension).

    Returns:
        ``(shard_size, shard_offset)`` measured in blocks.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]

    def _blocks(value):
        # Ceiling division: a partial trailing block still needs a scale.
        return (value + block_n - 1) // block_n

    offset_blocks = _blocks(shard_offset)
    size_blocks = _blocks(shard_size)
    return size_blocks, offset_blocks


91
92
93
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The logical ``(offset, size)`` of ``loaded_shard_id`` is rescaled by the
    ratio ``param.data.shape[0] / total`` using integer (floor) arithmetic,
    where ``total`` is the first element of ``shard_offsets["total"]``.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    total, _ = shard_offsets["total"]
    logical_offset, logical_size = shard_offsets[loaded_shard_id]

    stored_rows = param.data.shape[0]
    scaled_offset = logical_offset * stored_rows // total
    scaled_size = logical_size * stored_rows // total
    return scaled_size, scaled_offset


106
107
108
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for one shard.

    Fused modules (QKV and MLP) keep an array of length N holding one scale
    per "logical" matrix, so ``param`` has length N while ``loaded_weight``
    is the single on-disk value for one of those shards.

    Args:
        param: Fused array, indexed by shard position.
        loaded_weight: Scale read from disk; either 0-d or with a leading
            dimension of size 1.
        shard_id: ``"q"``/``"k"``/``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight)`` with ``loaded_weight`` squeezed of
        its leading size-1 dimension when it had one.

    Raises:
        ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are shapeless; compressed-tensors scales carry a
    # leading dimension of size 1 which is dropped here.
    if len(loaded_weight.shape) > 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


129
130
131
132
133
134
135
136
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    The right-hand offsets are re-based to start at 0 (this relies on the
    offsets supporting element-wise subtraction, as an array does) and its
    quant-state keys are renumbered from 0.
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    left_offsets = offsets[:2]
    right_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    left = {
        "bnb_shard_offsets": left_offsets,
        "bnb_quant_state": {0: states[0]},
    }
    right = {
        "bnb_shard_offsets": right_offsets,
        "bnb_quant_state": {
            idx: states[idx + 1] for idx in range(len(offsets) - 2)
        },
    }
    return left, right


165
class LinearMethodBase(QuantizeMethodBase):
    """Interface for (possibly quantized) linear methods.

    Subclasses allocate a layer's parameters in ``create_weights`` and run
    the layer's matmul in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights of a linear layer and attach them to it.

        Args:
            layer: Module that owns this linear method; weights are set as
                its attributes.
            input_size_per_partition: Weight input-dim size on this rank.
            output_partition_sizes: Output-dim size of each logical weight on
                this rank (e.g. for QKVLinear, the widths of Wq, Wk and Wv).
            input_size: Weight input-dim size summed across all ranks.
            output_size: Weight output-dim size summed across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes for the created weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the layer's weights on ``x`` (optionally adding ``bias``).

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method backed by a single dense, non-quantized weight."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocates one uninitialized matrix of shape
        # (sum(output_partition_sizes), input_size_per_partition): every
        # logical output partition is stacked along dim 0. No quantization
        # and no further sharding happen here; the per-rank sizes are used
        # exactly as given.
        loader = extra_weight_attrs.pop("weight_loader")
        stacked_rows = sum(output_partition_sizes)
        storage = torch.empty(
            stacked_rows, input_size_per_partition, dtype=params_dtype
        )
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )

        layer.register_parameter("weight", weight)
        # Whatever extra attributes remain (weight_loader was popped above)
        # are attached to the new parameter.
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Only the CPU platform needs a post-load step.
        if not current_platform.is_cpu():
            return
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
251
252


253
class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization config; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (tp_rank is forced to 0 and tp_size to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Store constructor arguments. The order matters:
        # get_quant_method() below receives this partially-initialized layer.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        self.quant_method: QuantizeMethodBase | None = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp

        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto each BasevLLMParameter."""
        vllm_params = (
            p for p in self.parameters() if isinstance(p, BasevLLMParameter)
        )
        for vllm_param in vllm_params:
            vllm_param.tp_rank = self.tp_rank
            vllm_param.tp_size = self.tp_size
305
306


307
# --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank allocates and loads the full (unsharded) weight and bias.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Only forwarded to LinearBase (which derives tp_rank/tp_size
                    from it); the weights of this layer are never sharded.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Full (not per-rank) input size is passed as the partition size
        # because nothing is sharded.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a whole checkpoint tensor into ``param`` (no sharding).

        Raises:
            AssertionError: If the (possibly reshaped) tensor's size differs
                from ``param``'s size.
        """
        # Special case for GGUF: a weight-type marker tensor stores its
        # scalar value on the parameter as ``weight_type``.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with the on-disk shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        # With skip_bias_add the bias is returned to the caller instead of
        # being added inside the matmul.
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

426

427
# --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Note: this constructor does
                       not read the argument; per-partition sizes come from a
                       ``self.output_sizes`` attribute that subclasses (e.g.
                       MergedColumnParallelLinear) set before calling it.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        output_sizes: list[int] | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_rank/tp_size
        # are needed before LinearBase.__init__ (which sets them again) to
        # compute the per-rank partition sizes.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Methods listed in WEIGHT_LOADER_V2_SUPPORTED load through the
            # parameter's own load_column_parallel_weight.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` when partitions are misaligned.

        Only applies to layers with more than one output partition whose
        quant config exposes a usable ``weight_block_size``. The flag is set
        when any partition size is not a multiple of ``weight_block_size[0]``.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor, keeping this rank's slice along output_dim.

        The tensor is narrowed to ``[tp_rank * shard, (tp_rank + 1) * shard)``
        along ``param.output_dim`` unless the parameter has no output_dim or
        the weight is already sharded (``is_sharded_weight`` or bitsandbytes
        4-bit).
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0, (
                    f"GGUF dim {output_dim} of size {final_shape[output_dim]} "
                    f"is not divisible by tp_size={self.tp_size}"
                )
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded; this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every sub-matrix must split evenly across the tensor-parallel ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _fused_shard_offsets(self) -> list[tuple[int, int, int]]:
        """Return ``(shard_id, offset, size)`` for every sub-matrix of a
        tensor that is already fused on disk, in full (unsharded, unpacked)
        output-dimension units."""
        shard_offsets: list[tuple[int, int, int]] = []
        current_shard_offset = 0
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size
        return shard_offsets

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one sub-matrix (or an on-disk fused tensor) into ``param``.

        Args:
            param: Destination parameter. Optional attributes read here steer
                loading: ``output_dim``, ``packed_dim``/``packed_factor``,
                ``needs_scalar_to_array``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ``ignore_warning`` and the GGUF flags
                ``is_gguf_weight``/``is_gguf_weight_type``.
            loaded_weight: The tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded, or ``None`` when the checkpoint
                stores the already-fused tensor (e.g. Phi-3's
                ``gate_up_proj``). In that case the tensor is split and this
                method is called again once per sub-matrix.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected and materialized later.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            if use_bitsandbytes_4bit:
                # Unquantized (offset, size) of every sub-matrix plus the
                # total size; identical for every shard, so built once.
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Convert full-matrix coordinates into this rank's partition.
            shard_offset //= self.tp_size
            shard_size //= self.tp_size

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # NOTE(review): offset = size * shard index only equals the
                # true offset when every sub-matrix has the same size as the
                # one being loaded; confirm for unequal output_sizes.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load one sub-matrix into ``param`` via the v2 parameter API.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: The tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None``
                when the checkpoint stores the fused tensor.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        # Convert full-matrix coordinates into this rank's partition.
        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each query/key attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds a single KV
            # head, and each KV head is replicated on
            # tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's output offset of shard ``q``/``k``/``v`` (or
        the total per-rank size for ``"total"``); ``None`` for other keys."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's output size of shard ``q``/``k``/``v``;
        ``None`` for other keys."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for the q, k and v sub-matrices
        of a QKV tensor that is already fused on disk, in full (unsharded,
        unpacked) output-dimension units."""
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load the ``q``/``k``/``v`` shard into ``param`` via the v2 API.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_qkv_weight`` method.
            loaded_weight: The tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores the fused QKV tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load the ``q``/``k``/``v`` shard (or a fused QKV tensor) into
        ``param``.

        Args:
            param: Destination parameter. Optional attributes read here steer
                loading: ``output_dim``, ``packed_dim``/``packed_factor``,
                ``needs_scalar_to_array``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ``ignore_warning`` and the GGUF flags.
            loaded_weight: The tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores the already-fused tensor (e.g. Phi-3's
                ``qkv_proj``). In that case the tensor is split and this
                method is called again once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            # NOTE(review): unlike the non-GGUF path below, k/v are sliced by
            # tp_rank rather than tp_rank // num_kv_head_replicas; confirm
            # this is correct when KV heads are replicated.
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected and materialized later.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            if use_bitsandbytes_4bit:
                # Unquantized (offset, size) of q/k/v plus the total size;
                # identical for every shard, so built once.
                orig_qkv_offsets = {
                    shard_id: (offset, size) for shard_id, offset, size in shard_offsets
                }
                _, last_offset, last_size = shard_offsets[-1]
                orig_qkv_offsets["total"] = (last_offset + last_size, 0)

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank coordinates, shared with weight_loader_v2.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard: (
                        self._get_shard_offset_mapping(shard),
                        self._get_shard_size_mapping(shard),
                    )
                    for shard in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"),
                    0,
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are partitioned across ranks; KV heads may be
            # replicated, so several consecutive ranks read the same KV slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if ``reduce_results`` is False while the bias would be
            added inside ``forward`` (``bias=True`` and ``skip_bias_add=False``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded: every rank produces the full
        # (partial-sum) output.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Validate the configuration before allocating any weights. In forward,
        # the bias is added only on rank 0, so without the all-reduce the other
        # ranks' outputs would lack the bias.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias is replicated (full output_size on every rank).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` (legacy loader path).

        If ``param`` carries an ``input_dim`` attribute and the weight is not
        already sharded, ``loaded_weight`` is narrowed along that dimension to
        this rank's slice. A weight counts as already sharded when it has
        ``is_sharded_weight`` set or is bitsandbytes 4-bit. GGUF parameters
        are materialized from the loaded tensor's shape first. 0-dim tensors
        (e.g. per-tensor scales) are reshaped to shape ``(1,)``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Use `is not None` rather than truthiness: input_dim == 0 is a
            # valid sharding axis and must be split across TP ranks too,
            # otherwise the narrow() below would read out of range.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # Take this rank's contiguous slice along the input dimension.
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"{type(self).__name__}: shape mismatch while loading weight: "
            f"param {tuple(param_data.shape)} vs loaded "
            f"{tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor through the vLLM parameter API.

        Sharding is delegated to ``param.load_row_parallel_weight``. 0-dim
        tensors (e.g. per-tensor scales) are first reshaped to shape ``(1,)``.

        Args:
            param: Destination vLLM parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Multiply the input by this rank's row shard of A.

        If ``input_is_parallel`` is False, ``input_`` is split along its last
        dimension into ``tp_size`` chunks and this rank keeps chunk
        ``tp_rank``. When ``reduce_results`` is True and ``tp_size > 1``, the
        per-rank partial products are summed with an all-reduce.

        Returns:
            The output tensor if ``return_bias`` is False. Otherwise a tuple
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is True and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s