# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from abc import abstractmethod
from typing import Any

import torch
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes

logger = init_logger(__name__)

# Quant-method class names, compared against ``quant_method.__class__.__name__``,
# whose column-parallel layers are given ``weight_loader_v2`` instead of the
# legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]


def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset into BitBLAS tile units.

    If ``param`` has a non-None ``bitblas_tile_size`` attribute, both values
    are floor-divided by it. Otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after any adjustment.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is not None:
        return (shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size)

    return shard_size, shard_offset


def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Parameters that lack a ``marlin_tile_size`` attribute (or have it set to
    None) leave both values untouched.

    Returns:
        ``(shard_size, shard_offset)`` after any scaling.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized ``(offset, size)`` of ``loaded_shard_id`` is rescaled
    proportionally from the unquantized total (``shard_offsets["total"][0]``)
    to the quantized total, i.e. ``param.data.shape[0]``.

    Args:
        param: Quantized parameter whose first dimension is the quantized total.
        shard_offsets: Maps shard ids to ``(offset, size)`` in unquantized
            units; the ``"total"`` entry's first element is the total size.
        loaded_shard_id: Key into ``shard_offsets`` for the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``. Note the size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard array that one on-disk scale fills.

    For fused modules (QKV and MLP) we have an array of length N that holds
    one scale for each "logical" matrix, so the param is an array of length
    N. The ``loaded_weight`` corresponds to one of the shards on disk. Here,
    we slice the param based on the ``shard_id`` for loading.

    Args:
        param: Fused array with one entry per logical shard.
        loaded_weight: Value for one shard. It is either 0-d (AutoFP8) or has
            a leading dimension of size 1 (compressed-tensors).
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index into ``param``.

    Returns:
        ``(param[shard_id], loaded_weight)``, with the leading size-1
        dimension of ``loaded_weight`` removed when it has one.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or is neither a
            string nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names with the same ValueError as bad types instead
        # of leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    Splits the attributes into the first shard ("left") and all remaining
    shards ("right"). The right-hand offsets are rebased to start at 0, and
    its quant-state keys are renumbered from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    ``bnb_shard_offsets`` must support slicing and element-wise subtraction
    of a scalar (e.g. a numpy array), as in the example above.
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    offset_l = shard_offsets[:2]
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    quant_state_r = {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    }
    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes for the created weights
                (the layers in this file pass ``weight_loader`` here).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate a dense weight and register it as ``layer.weight``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``.
        ``extra_weight_attrs`` must contain ``weight_loader``. All remaining
        entries are attached to the new parameter as attributes.

        Raises:
            RuntimeError: If allocating the weight raises a CUDA
                out-of-memory error.
        """
        # The weights are not quantized. Any tensor-parallel sharding has
        # already been applied by the caller through the per-partition sizes
        # passed in; the amount of memory allocated here is
        # sum(output_partition_sizes) * input_size_per_partition.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        # Keep the ``try`` body limited to the allocation that can run out
        # of memory.
        try:
            weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug(
                    "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes
                )
                logger.debug(
                    "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes
                )
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight."
            ) from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Multiply ``x`` by ``layer.weight`` (plus ``bias``) via the dispatched GEMM."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer acts as the sole rank (rank 0 of 1).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size


@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The weight is not sharded: the full input size is used as the
        # per-partition input size.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer; also return the bias if ``return_bias`` is set.

        When ``skip_bias_add`` is set, the bias is not applied and is returned
        as the second element instead.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class:
                       per-shard sizes are taken from ``self.output_sizes``
                       when a subclass sets it before calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        output_sizes: list[int] | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED use the
            # parameter-driven loader; all others use the legacy one.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype has already resolved None to the default dtype.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim``, the full tensor is narrowed to
        this rank's ``tp_rank`` slice along that dim, unless the parameter is
        marked as already sharded or uses bitsandbytes 4-bit.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate column-parallel loading to the vLLM parameter itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute this rank's output columns, all-gathering if requested."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Must be set before ColumnParallelLinear.__init__, which checks for
        # ``self.output_sizes`` to compute the per-shard partition sizes.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        assert all(output_size % self.tp_size == 0 for output_size in output_sizes), (
            f"All output_sizes {output_sizes} must be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into this merged (packed) parameter.

        Args:
            param: Destination parameter owned by this layer.
            loaded_weight: Tensor read from the checkpoint. When
                ``loaded_shard_id`` is None and ``param`` has an
                ``output_dim``, it is split along that dim into the logical
                shards given by ``self.output_sizes``, and each shard is
                loaded recursively.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                shard being loaded, or None for an already-fused tensor.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected and materialized later rather
                # than copied into param.data here.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size

            # The unquantized offset table for bitsandbytes does not depend
            # on the shard being loaded, so build it once outside the loop.
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # NOTE(review): offset = size * shard index assumes every
                # logical shard has the same on-disk size; confirm for
                # layers with unequal output_sizes.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            # Select this shard's slot in the fused param, then this rank's
            # slice of the checkpoint tensor (unless it is already sharded).
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter; its ``output_dim`` is the axis
                along which ``loaded_weight`` is split.
            loaded_weight: The fused tensor read from the checkpoint.
        """
        # (shard_id, shard_offset, shard_size) for each entry of
        # self.output_sizes, laid out back to back in declaration order.
        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            # Re-enter the regular per-shard path with an explicit shard id.
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` via its merged-column loader.

        Args:
            param: Destination parameter; the copy is delegated to its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` naming which
                sub-matrix ``loaded_weight`` holds, or ``None`` when the
                checkpoint stores the merged tensor as a single entry.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Convert row counts into counts of block_n-sized blocks (ceil
            # division); the scale tensor presumably holds one entry per
            # output block.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
            ) // self.tp_size
            shard_size = (
                (self.output_sizes[loaded_shard_id] + block_n - 1)
                // block_n
                // self.tp_size
            )
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # No more KV heads than ranks: every rank keeps a single KV head,
            # and each KV head is replicated on num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        )
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's offset of ``loaded_shard_id`` along the output dim.

        ``"total"`` yields the combined q/k/v width of this rank's partition;
        any other unknown id yields ``None``.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's width of ``loaded_shard_id``; ``None`` if unknown."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Offsets/sizes refer to the full (unsharded) fused tensor on disk.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a q/k/v shard (or a fused qkv tensor) via ``param.load_qkv_weight``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or ``None`` when
                the checkpoint stores q/k/v fused in a single tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Ceil-divide into block_n-sized blocks.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        # NOTE: load_qkv_weight's `num_heads` argument receives the KV head
        # replication factor here.
        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Legacy loader: copy this rank's part of a q/k/v tensor into ``param``.

        Args:
            param: Destination parameter; optional attributes such as
                ``output_dim``, ``packed_dim``, ``use_bitsandbytes_4bit`` and
                the GGUF flags select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or ``None`` when
                the checkpoint stores q/k/v fused in a single tensor.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                (
                    "k",
                    self.total_num_heads * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
                (
                    "v",
                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (
                            self.total_num_heads * self.head_size,
                            self.total_num_kv_heads * self.head_size,
                        ),
                        "v": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size,
                            self.total_num_kv_heads * self.head_size,
                        ),
                        "total": (
                            (self.total_num_heads + 2 * self.total_num_kv_heads)
                            * self.head_size,
                            0,
                        ),
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                # Recurse with an explicit shard id.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank layout as weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "total": (
                        (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                # K/V slices are shared by num_kv_head_replicas consecutive
                # ranks.
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce each rank would add the full bias to its
        # partial sum, so this combination is rejected.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Unless the parameter is flagged ``is_sharded_weight`` or uses
        bitsandbytes 4-bit loading, the checkpoint tensor is narrowed along
        ``param.input_dim`` to the block owned by ``self.tp_rank``. An
        uninitialized GGUF parameter is first materialized with that per-rank
        shape.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare with None: axis 0 is a valid input_dim and must still be
            # divided, otherwise the narrow below reads past the tensor end
            # on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Run the row-parallel matmul and optionally all-reduce the result.

        Returns the output alone when ``return_bias`` is false, otherwise the
        pair ``(output, bias)`` where bias is only non-``None`` if
        ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s