linear.py 57.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

logger = init_logger(__name__)

# Class names of quantization linear methods whose parameters are loaded via
# the parameter-driven `weight_loader_v2` path. ColumnParallelLinear compares
# `quant_method.__class__.__name__` against this list to choose between
# `weight_loader_v2` and the legacy `weight_loader`.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset into BitBLAS tile units.

    If ``param`` carries a non-``None`` ``bitblas_tile_size`` attribute, both
    values are floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after adjustment.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is not None:
        return (shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size)

    return shard_size, shard_offset


75
76
77
78
79
80
81
82
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Parameters without a ``marlin_tile_size`` attribute (or with it set to
    ``None``) leave both values untouched.

    Returns:
        ``(shard_size, shard_offset)`` after adjustment.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


83
84
85
86
87
88
89
90
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert an element-level shard size/offset into block-scale units.

    ``weight_block_size[0]`` is used as the block length; each value is
    rounded up to a whole number of blocks via ``(v + block_n - 1) // block_n``.

    Returns:
        ``(shard_size, shard_offset)`` expressed in blocks.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    size_in_blocks, offset_in_blocks = (
        (value + block_n - 1) // block_n for value in (shard_size, shard_offset)
    )
    return size_in_blocks, offset_in_blocks


91
92
93
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized offset/size of ``loaded_shard_id`` is rescaled in
    proportion to the first dimension of ``param.data``, i.e. the quantized
    total, relative to the unquantized total.

    Args:
        param: The (quantized) parameter being loaded into.
        shard_offsets: Maps each shard id to its unquantized
            ``(offset, size)``; the ``"total"`` entry's first element is the
            unquantized total size.
        loaded_shard_id: Key of the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)``, computed with floor division.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


106
107
108
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: The fused scale array.
        loaded_weight: A scalar (0-dim) tensor or a tensor whose first
            dimension has length 1.
        shard_id: Either ``"q"``, ``"k"``, ``"v"`` (mapped to 0, 1, 2) or an
            integer index into ``param``.

    Returns:
        ``(param[shard_id], loaded_weight)`` where ``loaded_weight`` has had
        its leading length-1 dimension removed if it had one.

    Raises:
        ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


129
130
131
132
133
134
135
136
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Separate the BitsAndBytes 4-bit shard.

    The first shard is split off on the left; the remaining shards form the
    right part, with their offsets shifted so they start at 0 and their
    quant-state keys renumbered from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    Note: ``bnb_shard_offsets`` must support slicing and element-wise
    subtraction of a scalar (e.g. a numpy array or tensor).
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    offset_l = shard_offsets[:2]
    # Shift the remaining boundaries so the right part starts at offset 0.
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    quant_state_r = {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    }
    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


165
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                parameters (callers in this module pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Creates a single unquantized weight of shape
        # (sum(output_partition_sizes), input_size_per_partition), i.e. sized
        # for this rank's partition. No sharding happens here; the registered
        # weight_loader is responsible for placing data into it.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        # Any remaining extra attributes are attached to the weight.
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
251
252


253
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer behaves as a single rank (rank 0 of 1).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto its BasevLLMParameter params."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
305
306


307
# --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All linear layers support a quant method.
        assert self.quant_method is not None
        # The weight is not partitioned: the full input/output sizes are used
        # for both the per-partition and the global sizes.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` unchanged (no sharding)."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        # When skip_bias_add is set, the bias was not applied above and is
        # handed back to the caller instead.
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

426

427
# --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed before LinearBase.__init__ to size the
        # partitions below.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype is what LinearBase resolved params_dtype to
            # (the default dtype when None), matching the weight dtype.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag the layer when a fused output partition is not block-aligned.

        Only applies when the quant config has a ``weight_block_size`` and
        the layer has more than one output partition; sets
        ``allow_fp8_block_shape_mismatch`` to True if any partition size is
        not a multiple of ``weight_block_size[0]``.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        Unless the weight is already sharded (or bitsandbytes 4-bit), the
        checkpoint tensor is narrowed along ``param.output_dim`` to the slice
        belonging to ``self.tp_rank``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"into a parameter of shape {tuple(param_data.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate column-parallel loading to the vLLM parameter itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
644
        quant_config: Quantization configure.
645
646
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
647
        return_bias: If true, return bias together with outputs in forward pass.
648
649
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
650
651
    """

652
653
654
655
656
657
658
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Set before ColumnParallelLinear.__init__, which checks for
        # `output_sizes` to compute per-partition output sizes.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every merged output must split evenly across TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
688
        loaded_shard_id: int | None = None,
689
    ):
690
691
692
693
694
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
695
696
697
698
699
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
700
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
701
                }
702
703
            return

704
705
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
706
707
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
708

709
            if loaded_shard_id is not None:
710
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
711
712
713
714
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
715

716
717
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
718
719
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
720

721
        if loaded_shard_id is None:
722
723
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
724
            if output_dim is None:
725
                if needs_scalar_to_array:
726
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
727
728
                        param_data, loaded_weight, 0
                    )
729

730
731
732
733
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
734
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
735
            shard_offsets: list[tuple[int, int, int]] = []
736
737
738
739
740
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
741
                # Special case for Quantization.
742
743
744
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
745
746
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
747
                    # Special case for Marlin.
748
                    shard_size, shard_offset = adjust_marlin_shard(
749
750
                        param, shard_size, shard_offset
                    )
751

752
                shard_size, shard_offset = adjust_bitblas_shard(
753
754
                    param, shard_size, shard_offset
                )
755

756
                if use_bitsandbytes_4bit:
757
758
759
760
761
762
763
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
764
765
                        param, orig_offsets, str(shard_id)
                    )
766

767
                loaded_weight_shard = loaded_weight.narrow(
768
769
                    output_dim, shard_offset, shard_size
                )
770
771
772
773
774
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
775
776
777
778
779
780
781
782
783
784
785
786
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            shard_offset //= self.tp_size
            shard_size //= self.tp_size

787
            # Special case for quantization.
788
789
790
791
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
792
793
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
794
                # Special case for Marlin.
795
                shard_size, shard_offset = adjust_marlin_shard(
796
797
                    param, shard_size, shard_offset
                )
798
            shard_size, shard_offset = adjust_bitblas_shard(
799
800
                param, shard_size, shard_offset
            )
801

802
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
803
804
805
806
807
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

808
            if use_bitsandbytes_4bit:
809
                shard_size = loaded_weight.shape[output_dim]
810
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
811

812
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
813
            start_idx = self.tp_rank * shard_size
814
            if not is_sharded_weight:
815
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
816
817
818
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
819
820
                param_data, loaded_weight, loaded_shard_id
            )
821

822
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
823
824
825
826
827
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
828
829
                    "the same for all partitions."
                )
830

831
832
833
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

834
835
836
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Split a checkpoint tensor whose sub-projections are stored fused on
        disk, and hand each slice to ``weight_loader_v2`` with its shard id.

        Models whose MLP layers are already fused on disk provide no shard id.
        The slice boundaries are therefore recovered from ``self.output_sizes``:
        shard ``i`` starts at ``sum(output_sizes[:i])`` and spans
        ``output_sizes[i]`` rows along ``param.output_dim``.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        sizes = list(self.output_sizes)
        # Running start offsets: 0, s0, s0 + s1, ... (one per shard).
        starts = itertools.accumulate(sizes, initial=0)
        shard_layout = [
            (shard_idx, start, width)
            for shard_idx, (start, width) in enumerate(zip(starts, sizes))
        ]

        for shard_idx, start, width in shard_layout:
            # Packed (quantized) parameters need their offset and size adjusted
            # to account for the packing when it is along the output dim.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                width, start = param.adjust_shard_indexes_for_packing(
                    shard_size=width, shard_offset=start
                )

            shard_slice = loaded_weight.narrow(param.output_dim, start, width)
            self.weight_loader_v2(param, shard_slice, shard_idx)

870
871
872
873
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a merged-column weight through the vLLM parameter API.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-projection ``loaded_weight`` belongs to, or ``None`` when
                the checkpoint tensor already covers all sub-projections.
        """
        if loaded_shard_id is not None:
            assert loaded_shard_id < len(self.output_sizes)

            # Position of this sub-projection inside the full output, computed
            # from the un-divided output sizes.
            offset = sum(self.output_sizes[:loaded_shard_id])
            size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                # Block-quantized scales: bounds are rescaled by
                # adjust_block_scale_shard (presumably to block granularity).
                size, offset = adjust_block_scale_shard(
                    getattr(self, "weight_block_size", None), size, offset
                )

            # Offsets/sizes are handed over divided by tp_size; tp_rank is
            # forwarded so the parameter can pick this rank's portion.
            param.load_merged_column_weight(
                loaded_weight=loaded_weight,
                shard_id=loaded_shard_id,
                shard_offset=offset // self.tp_size,
                shard_size=size // self.tp_size,
                tp_rank=self.tp_rank,
            )
            return

        # No shard id: the checkpoint stores the sub-projections fused.
        if isinstance(param, PerTensorScaleParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
        elif type(param) in (RowvLLMParameter, BasevLLMParameter):
            param.load_merged_column_weight(loaded_weight=loaded_weight)
        else:
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
908

909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension as ``[q | k | v]``. The layer is parallelized along
    the head dimension. When the number of key/value heads is smaller than the
    number of query heads (e.g., multi-query/grouped-query attention), the
    key/value head may be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head (used for q and k).
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size. Only the v projection uses it;
                     q and k always use head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank keeps one KV head,
            # and each KV head is replicated on tp_size // total_num_kv_heads
            # ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Per-rank q/k/v widths multiplied back by tp_size. With replicated KV
        # heads the k/v entries exceed the checkpoint's KV width.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]
        output_size = sum(self.output_sizes)

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _qkv_shard_layout(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Return ``{shard_id: (offset, size)}`` for q, k and v packed
        back-to-back along the output dimension, plus a ``"total"`` entry
        ``(total_width, 0)``.

        Pass the global head counts to describe a fused checkpoint tensor, or
        the per-rank head counts to describe this rank's parameter. A new dict
        is built on every call, so callers may hand it off freely.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Per-rank start offset of ``loaded_shard_id`` ("q", "k", "v" or
        "total") inside this rank's fused parameter; None for unknown ids."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Per-rank width of ``loaded_shard_id`` ("q", "k" or "v"); None for
        any other id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Layout of q/k/v in the full (un-sharded) checkpoint tensor.
        global_layout = self._qkv_shard_layout(
            self.total_num_heads, self.total_num_kv_heads
        )
        for shard_id in ("q", "k", "v"):
            shard_offset, shard_size = global_layout[shard_id]
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v weight through the vLLM parameter API.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``; ``None`` when the
                checkpoint stores q/k/v already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into ``param`` (non-v2 loader path).

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for one projection;
                ``None`` when the checkpoint stores q/k/v fused. In that case,
                if ``param`` has an ``output_dim``, the tensor is split along
                it and this method is called once per projection.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Layout of q/k/v in the full (un-sharded) checkpoint tensor.
            global_layout = self._qkv_shard_layout(
                self.total_num_heads, self.total_num_kv_heads
            )
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id in ("q", "k", "v"):
                shard_offset, shard_size = global_layout[shard_id]
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # bitsandbytes 4-bit: the slice comes from
                    # adjust_bitsandbytes_4bit_shard given the global layout.
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param,
                        self._qkv_shard_layout(
                            self.total_num_heads, self.total_num_kv_heads
                        ),
                        shard_id,
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Where this projection sits inside this rank's fused parameter.
            shard_offset, shard_size = self._qkv_shard_layout(
                self.num_heads, self.num_kv_heads
            )[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param,
                    self._qkv_shard_layout(self.num_heads, self.num_kv_heads),
                    loaded_shard_id,
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                # Replicated KV heads: num_kv_head_replicas consecutive ranks
                # map to the same KV slice of the checkpoint tensor.
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1293
# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized; only
              rank 0 adds it to its partial output.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Raises:
        ValueError: if ``reduce_results`` is False while a bias would be
            added inside the layer (``bias`` and not ``skip_bias_add``).
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not partitioned.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce, a bias added inside the layer would be
        # applied to un-reduced partial outputs; reject that configuration.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``input_dim`` and the weight is not already
        sharded, the slice ``[tp_rank * n, (tp_rank + 1) * n)`` is taken along
        ``input_dim``, where ``n`` is the parameter's size on that dim.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # input_dim may legitimately be 0, so compare against None rather
            # than relying on truthiness. Otherwise the parameter would be
            # materialized un-sharded, and the narrow below would load the
            # wrong slice (or overrun the tensor for tp_rank > 0).
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via ``param.load_row_parallel_weight``,
        promoting 0-d scalars to shape ``(1,)`` first."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute this rank's partial product and optionally all-reduce it.

        Returns the output alone when ``return_bias`` is False; otherwise
        ``(output, bias)``, where ``bias`` is ``self.bias`` only when
        ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank input width and TP settings for repr()."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s