linear.py 57.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import CustomOp
21
from vllm.model_executor.layers.quantization.base_config import (
22
23
24
    QuantizationConfig,
    QuantizeMethodBase,
)
25
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
26
27
28
29
30
31
32
33
34
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
35
from vllm.model_executor.utils import set_weight_attrs
36
from vllm.platforms import current_platform
37
38
39

logger = init_logger(__name__)

# Class names of the linear quant methods for which layers in this module
# hand ``weight_loader_v2`` to ``create_weights``.  Membership is tested with
# ``quant_method.__class__.__name__``; methods not listed here receive the
# legacy ``weight_loader`` instead.
WEIGHT_LOADER_V2_SUPPORTED = [
    # Unquantized / compressed-tensors methods.
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # BitBLAS-based methods.
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    # Remaining methods (order kept as originally registered).
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
65

66

67
68
69
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset into BitBLAS tile units.

    When ``param`` has a ``bitblas_tile_size`` attribute, both values are
    floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` as a tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


75
76
77
78
79
80
81
82
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset by the Marlin tile size.

    When ``param`` has a ``marlin_tile_size`` attribute, both values are
    multiplied by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` as a tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


83
84
85
86
87
88
89
90
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a shard's size and offset into counts of ``block_n`` blocks.

    ``block_n`` is ``weight_block_size[0]``; both values are divided by it
    and rounded up.

    Returns:
        ``(shard_size, shard_offset)`` expressed in blocks.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]

    def _ceil_blocks(value):
        # Ceiling division by block_n.
        return (value + block_n - 1) // block_n

    return _ceil_blocks(shard_size), _ceil_blocks(shard_offset)


91
92
93
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized ``(offset, size)`` stored under ``loaded_shard_id`` are
    rescaled by ``param.data.shape[0] / total``, where ``total`` is the first
    item of ``shard_offsets["total"]``.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    unquantized_total, _unused = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    packed_total = param.data.shape[0]

    def _rescale(value):
        # Multiply first so integer floor division loses as little as possible.
        return value * packed_total // unquantized_total

    return _rescale(size), _rescale(offset)


106
107
108
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-shard scale array for one on-disk shard.

    For fused modules (QKV and MLP) we have an array of length N that holds
    one scale for each "logical" matrix, so ``param`` has length N, while
    ``loaded_weight`` corresponds to one of the shards on disk.

    Args:
        param: The fused array, indexed by shard position.
        loaded_weight: The value loaded for one shard.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2) or an int
            index.

    Returns:
        ``(param[index], loaded_weight)``; a ``loaded_weight`` with a
        non-empty shape (first dim asserted to be 1) is reduced to
        ``loaded_weight[0]``.

    Raises:
        ValueError: if ``shard_id`` is neither a str nor an int.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape;
    # compressed-tensors scales do have a shape.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


129
130
131
132
133
134
135
136
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    For example, given::

        {
            'bnb_shard_offsets': array([0, 4, 8, 16]),
            'bnb_quant_state': {0: ..., 1: ..., 2: ...},
        }

    the function returns the pair::

        {'bnb_shard_offsets': array([0, 4]), 'bnb_quant_state': {0: ...}}
        {'bnb_shard_offsets': array([0, 4, 12]),
         'bnb_quant_state': {0: ..., 1: ...}}

    The right-hand offsets are shifted so they start at 0, and its quant-state
    keys are renumbered starting from 0.
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    states = bnb_weight_attrs["bnb_quant_state"]
    boundary = offsets[1]

    left = dict(
        bnb_shard_offsets=offsets[:2],
        bnb_quant_state={0: states[0]},
    )
    right = dict(
        bnb_shard_offsets=offsets[1:] - boundary,
        bnb_quant_state={k - 1: states[k] for k in range(1, len(offsets) - 1)},
    )
    return left, right


165
class LinearMethodBase(QuantizeMethodBase):
    """Interface implemented by every (possibly quantized) linear method."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the weights of a linear layer and set them on ``layer``.

        Args:
            layer: The layer that is using this LinearMethodBase factory.
            input_size_per_partition: Size of the weight's input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list
                containing the widths of Wq, Wk and Wv on rank X.
            input_size: Size of the weight's input dim across all ranks.
            output_size: Size of the weight's output dim across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes supplied by the calling
                layer (the layers in this module pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights held by ``layer`` to ``x``.

        ``create_weights`` must have been called on ``layer`` beforehand.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The tensor is allocated (uninitialized) with the per-rank shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and
        ``params_dtype``; ``input_size`` and ``output_size`` are not used.
        ``extra_weight_attrs`` must contain ``weight_loader``, which is given
        to the parameter; the remaining entries are attached to it with
        ``set_weight_attrs``.
        """
        loader = extra_weight_attrs.pop("weight_loader")
        rows = sum(output_partition_sizes)
        storage = torch.empty(rows, input_size_per_partition, dtype=params_dtype)
        weight = ModelWeightParameter(
            data=storage,
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Hand ``layer`` to the CPU GEMM dispatcher when running on CPU.

        On every other platform this is a no-op.
        """
        if not current_platform.is_cpu():
            return
        # Imported here because only the CPU path needs it.
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the GEMM chosen by ``dispatch_unquantized_gemm()`` on ``x``,
        ``layer.weight`` and ``bias``."""
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
251
252


253
class LinearBase(CustomOp):
    """Base linear layer shared by the linear layers in this module.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters (defaults to
            ``torch.get_default_dtype()``).
        quant_config: Quantization config; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is 0 and ``tp_size`` is 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Record the constructor arguments. These are set before the quant
        # method is created because get_quant_method receives this layer.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix
        # Off by default; subclasses may turn it on.
        self.allow_fp8_block_shape_mismatch = False

        self.quant_method: QuantizeMethodBase | None = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp

        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto each of its
        ``BasevLLMParameter`` parameters."""
        vllm_params = (p for p in self.parameters() if isinstance(p, BasevLLMParameter))
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
305
306


307
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full weight: weights are created with the full
    ``input_size``/``output_size`` and loaded without any slicing.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on the weights of replicated linear layers;
                    it is only forwarded to ``LinearBase``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Full (unsharded) sizes are passed for both per-partition and
        # global dimensions.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full checkpoint tensor into ``param``.

        Handles the GGUF special cases (weight-type scalars and
        ``UninitializedParameter`` materialization) and promotes 0-dim
        tensors, such as AutoFP8 scales, to shape ``(1,)`` before copying.

        Raises:
            AssertionError: if ``param`` and ``loaded_weight`` still differ in
                size after those adjustments.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The trailing space in the first fragment keeps the two halves of
        # the message from running together.
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns the output alone when ``return_bias`` is false; otherwise
        ``(output, bias)`` where ``bias`` is ``self.bias`` only if
        ``skip_bias_add`` is set (and ``None`` otherwise).
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

423

424
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: Not read by this class; kept for API compatibility.
                      Packed per-partition sizes (e.g. the QKV widths) come
                      from ``self.output_sizes`` when a subclass sets that
                      attribute before calling this constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        output_sizes: list[int] | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # tp_rank/tp_size are needed to size the partitions below, before
        # LinearBase.__init__ runs (which assigns the same values again).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition. This
        # reads the ``output_sizes`` attribute set by subclasses, not the
        # ``output_sizes`` argument.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # Each rank holds only its output_size_per_partition slice of
            # the bias; use the resolved dtype, as ReplicatedLinear does.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Enable ``allow_fp8_block_shape_mismatch`` for misaligned partitions.

        The flag is set when the quant config has a non-empty
        ``weight_block_size``, the layer has more than one output partition,
        and at least one partition size is not a multiple of
        ``block_n = int(weight_block_size[0])``. An unparsable or
        non-positive ``block_n`` leaves the flag unchanged.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor into this rank's part of ``param``.

        Unless the weight is already sharded (``is_sharded_weight``) or is a
        bitsandbytes 4-bit weight, ``loaded_weight`` is narrowed along
        ``param.output_dim`` to the contiguous slice owned by ``tp_rank``.
        GGUF weight-type scalars and ``UninitializedParameter``
        materialization are handled first, and 0-dim tensors (e.g. AutoFP8
        scales) are promoted to shape ``(1,)`` before the copy.

        Raises:
            AssertionError: if the resulting shapes do not match.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's share of
        # the output dimension.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via the parameter's own ``load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer, all-gathering the output when ``gather_output``
        is set and ``tp_size > 1``.

        Returns the output alone when ``return_bias`` is false; otherwise
        ``(output, bias)`` where ``bias`` is ``self.bias`` only if
        ``skip_bias_add`` is set (and ``None`` otherwise).
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrices are not sharded and this layer
                    is treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        # tp_size/tp_rank are assigned before the base-class constructor runs
        # because the divisibility check below already needs tp_size.
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every logical sub-matrix must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _fused_shard_offsets(self) -> list[tuple[int, int, int]]:
        """Return ``(shard_id, offset, size)`` for each logical sub-matrix of
        the full (unsharded) fused weight, in ``output_sizes`` order."""
        offsets = itertools.accumulate(self.output_sizes, initial=0)
        return [
            (shard_id, offset, size)
            for shard_id, (offset, size) in enumerate(zip(offsets, self.output_sizes))
        ]

    def _bnb_orig_offsets(self) -> dict[str, tuple[int, int]]:
        """Return the bitsandbytes offset table for the fused weight.

        The table maps ``str(shard_id)`` to ``(offset, size)`` and also holds
        ``"total": (self.output_size, 0)``.
        """
        orig_offsets = {
            str(shard_id): (offset, size)
            for shard_id, offset, size in self._fused_shard_offsets()
        }
        orig_offsets["total"] = (self.output_size, 0)
        return orig_offsets

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Copy a checkpoint tensor into this rank's slice of ``param``.

        Args:
            param: destination parameter owned by this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: index into ``output_sizes`` of the sub-matrix
                being loaded. ``None`` means the checkpoint stores the
                already-fused tensor (e.g. Phi-3's gate_up_proj). In that case
                the tensor is split per sub-matrix and each piece is loaded
                through a recursive call.
        """
        # Special case for GGUF: initialize the GGUF param after we know the
        # quantize type.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i in range(len(self.output_sizes))
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            # The bnb offset table does not depend on the shard, so it is
            # built once here rather than once per loop iteration.
            bnb_orig_offsets = (
                self._bnb_orig_offsets() if use_bitsandbytes_4bit else None
            )
            for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, bnb_orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            # Position of the requested sub-matrix in the full fused weight.
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Scale down to this rank's partition of the fused weight.
            shard_offset //= self.tp_size
            shard_size //= self.tp_size

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor through the v2 parameter API.

        Args:
            param: destination vLLM parameter.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: index into ``output_sizes`` of the sub-matrix
                being loaded, or ``None`` for an already-fused tensor.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        # Scale down to this rank's partition of the fused weight.
        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head, and each
            # KV head is replicated on tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _qkv_offset_table(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Return ``{shard_id: (offset, size)}`` for q, k and v laid out
        back to back, plus ``"total": (total_size, 0)``.

        Pass the global head counts for the full fused weight, or the
        per-rank head counts for this rank's partition.
        """
        q_size = num_heads * self.head_size
        k_size = num_kv_heads * self.head_size
        v_size = num_kv_heads * self.v_head_size
        return {
            "q": (0, q_size),
            "k": (q_size, k_size),
            "v": (q_size + k_size, v_size),
            "total": (q_size + k_size + v_size, 0),
        }

    def _fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v within the full
        (unsharded) fused QKV weight."""
        table = self._qkv_offset_table(self.total_num_heads, self.total_num_kv_heads)
        return [(shard_id, *table[shard_id]) for shard_id in ("q", "k", "v")]

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's offset for ``loaded_shard_id`` (or ``"total"``
        for the full partition size), or ``None`` for an unknown id."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's size for ``loaded_shard_id``, or ``None`` for an
        unknown id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor through the v2 parameter API.

        Args:
            param: destination vLLM parameter.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` for an
                already-fused QKV tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Copy a checkpoint tensor into this rank's slice of ``param``.

        Args:
            param: destination parameter owned by this layer.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``. ``None`` means the
                checkpoint stores the already-fused QKV tensor (e.g. Phi-3's
                qkv_proj). In that case the tensor is split and each piece is
                loaded through a recursive call.
        """
        # Special case for GGUF: initialize the GGUF param after we know the
        # quantize type.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            # When KV heads are replicated (num_kv_head_replicas > 1), each
            # distinct K/V slice is shared by num_kv_head_replicas consecutive
            # ranks. This is the same rule as the non-GGUF path below.
            # Splitting K/V by tp_size would instead produce slices smaller
            # than one head.
            if loaded_shard_id in ("k", "v"):
                num_shards = self.tp_size // self.num_kv_head_replicas
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            else:
                num_shards = self.tp_size
                shard_rank = self.tp_rank
            shard_size = loaded_weight.size(output_dim) // num_shards
            start_idx = shard_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            # The bnb offset table (global head counts) does not depend on the
            # shard, so it is built once here.
            orig_qkv_offsets = self._qkv_offset_table(
                self.total_num_heads, self.total_num_kv_heads
            )
            for shard_id, shard_offset, shard_size in self._fused_shard_offsets():
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offsets/sizes of q, k and v within this rank's partition.
            partition_offsets = self._qkv_offset_table(
                self.num_heads, self.num_kv_heads
            )
            shard_offset, shard_size = partition_offsets[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, partition_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # q is split across all ranks. Ranks that replicate the same KV
            # heads read the same k/v slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Each rank computes the partial product X_i A_i; when ``reduce_results``
    is true the partial products are summed with an all-reduce.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized: every
              rank holds the full bias vector, but it is only added on
              rank 0 so it is not counted more than once after the
              all-reduce.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first (input) dimension; with
        # disable_tp the layer behaves as a single, unsharded rank.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded in row parallelism.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods that understand vLLM parameter classes use the
            # v2 loader; everything else goes through the legacy loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce each rank's output is only a partial sum,
        # so adding the full bias here would not give a meaningful result.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's portion of ``loaded_weight`` into ``param``.

        Used for the bias and for quantization methods whose class name is
        not in ``WEIGHT_LOADER_V2_SUPPORTED``. When ``param`` carries an
        ``input_dim`` attribute, the checkpoint tensor is narrowed along that
        dimension to the slice owned by ``self.tp_rank``, unless the weight
        is already sharded (``is_sharded_weight``) or was loaded by
        bitsandbytes 4-bit.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion, so there is
        # no need to narrow.
        is_sharded_weight = (
            getattr(param, "is_sharded_weight", False) or use_bitsandbytes_4bit
        )

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        # Compare against None explicitly: input_dim == 0 is a valid
        # dimension and must still be divided by the TP size.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's row-parallel loader.

        The sharding logic lives in ``param.load_row_parallel_weight``; this
        method only normalizes 0-dim tensors to shape ``(1,)`` first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel linear layer.

        If ``input_is_parallel`` is false, ``input_`` is split along its last
        dimension into ``tp_size`` chunks and this rank's chunk is used.

        Returns:
            ``output`` alone when ``return_bias`` is false, otherwise the
            tuple ``(output, output_bias)``. ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank input size and TP configuration for repr."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s