"vscode:/vscode.git/clone" did not exist on "68ffbca7e462cfa6a32b46dabc9a604c7c1b918d"
linear.py 60.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any
7
8

import torch
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
16
17
18
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
19
from vllm.logger import init_logger
20
from vllm.model_executor.custom_op import PluggableLayer
21
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    linear_batch_invariant,
    vllm_is_batch_invariant,
)
25
from vllm.model_executor.layers.quantization.base_config import (
26
27
28
    QuantizationConfig,
    QuantizeMethodBase,
)
29
30
31
from vllm.model_executor.layers.utils import (
    dispatch_unquantized_gemm,
)
32
33
34
35
36
37
38
39
40
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
41
from vllm.model_executor.utils import set_weight_attrs
42
from vllm.platforms import current_platform
43
44
45

# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)

46
# Class names of the linear quantization methods that are given
# `weight_loader_v2` instead of the legacy `weight_loader`.
# Membership is tested against `quant_method.__class__.__name__`, so entries
# are plain strings rather than class objects.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
66

67

68
69
70
71
72
73
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        ``(shard_size, shard_offset)``. Both values are multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise they are returned unchanged.
    """
    marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        # Not a Marlin-tiled parameter: nothing to rescale.
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


80
81
82
83
84
def adjust_block_scale_shard(
    weight_block_size: tuple[int, ...] | None,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
85
86
87
88
89
90
91
    assert weight_block_size is not None
    block_n = weight_block_size[0]
    shard_offset = (shard_offset + block_n - 1) // block_n
    shard_size = (shard_size + block_n - 1) // block_n
    return shard_size, shard_offset


92
def adjust_bitsandbytes_4bit_shard(
    param: Parameter,
    shard_offsets: dict[str, tuple[int, int]],
    loaded_shard_id: str,
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The shard's original offset and size are rescaled by the ratio between
    the first dimension of ``param.data`` and the un-quantized total stored
    under ``shard_offsets["total"]``.

    Args:
        param: Quantized parameter; ``param.data.shape[0]`` is the quantized
            total length.
        shard_offsets: Maps each shard id to ``(offset, size)``. It must also
            contain a ``"total"`` entry whose first element is the
            un-quantized total length.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` — note that the size comes
        first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before the integer division to keep the rescaling exact
    # whenever the result is an integer.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


109
110
111
112
113
def adjust_scalar_to_fused_array(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
114
115
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
116
117
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

133
    return param_data[shard_id], loaded_weight
134
135


136
137
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
138
139
140
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    The first shard is split off on its own; the remaining shards are
    re-based so that their offsets start at zero and their quant-state keys
    start at 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    Args:
        bnb_weight_attrs: Mapping with ``"bnb_shard_offsets"`` (an array-like
            that supports slicing and element-wise subtraction) and
            ``"bnb_quant_state"`` (indexed by shard number).

    Returns:
        ``(left, right)`` dictionaries with the same two keys.
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    offset_l = shard_offsets[:2]
    # Re-base the remaining offsets on the end of the first shard.
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
    quant_state_r = {
        i - 1: bnb_weight_attrs["bnb_quant_state"][i]
        for i in range(1, len(shard_offsets) - 1)
    }
    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


174
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list that contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes supplied by the
                calling layer (the layers in this file pass
                ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: The layer holding the weights.
            x: The input tensor.
            bias: Optional bias tensor.

        Returns:
            The output tensor.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the unquantized ``weight`` parameter and register it.

        ``extra_weight_attrs`` must contain ``weight_loader``; it is popped
        and handed to the parameter, and the remaining entries are attached
        through ``set_weight_attrs``.
        """
        # This method creates unquantized linear weights.
        # The weights are not quantized, and they are not sharded.
        # The amount of memory allocated for the weights is
        # sum(output_partition_sizes) * input_size_per_partition.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            # Imported locally, as in the original code.
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear output for ``x`` using ``layer.weight``.

        Uses ``linear_batch_invariant`` when batch-invariant mode is on and
        the platform is CUDA-alike; otherwise uses the callable returned by
        ``dispatch_unquantized_gemm()``.
        """
        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
            return linear_batch_invariant(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
262
263


264
class LinearBase(PluggableLayer):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False
        # Without a quantization config, fall back to the plain
        # (unquantized) linear method.
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled the layer behaves as rank 0 of a world of size 1.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are instances of ``BasevLLMParameter`` are
        updated; others are left untouched.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
316
317


318
# --8<-- [start:replicated_linear]
319
@PluggableLayer.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # This has to happen before super().__init__() because the quant
        # method is resolved there with this layer as an argument.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.

        Raises:
            AssertionError: If the sizes of ``param`` and ``loaded_weight``
                differ after the 0-dim reshape.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the linear method to ``x``.

        Returns:
            The output tensor alone when ``return_bias`` is false; otherwise
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        # When skip_bias_add is set, the bias is returned rather than applied.
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a summary of the layer's sizes and whether it has a bias."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

437

438
# --8<-- [start:column_parallel_linear]
439
@PluggableLayer.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        # tp_rank/tp_size are needed before super().__init__() to compute the
        # partition sizes; LinearBase.__init__ sets them again to the same
        # values.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size) for output_size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2
            # loader; everything else keeps the legacy loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype is the resolved dtype (LinearBase replaces
            # None with the default dtype), matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Set ``allow_fp8_block_shape_mismatch`` when partitions misalign.

        The flag is set only if the quant config exposes a non-empty
        ``weight_block_size``, the layer has more than one output partition,
        and at least one partition size is not a multiple of the first
        block-size entry. In every other case the method returns without
        changing anything.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            # Block size is not interpretable as an integer: leave the flag.
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        The weight is narrowed along ``param.output_dim`` to this rank's
        shard unless the parameter is marked ``is_sharded_weight`` or
        ``use_bitsandbytes_4bit``, or has no ``output_dim``.

        Raises:
            AssertionError: If the final shapes of the parameter data and the
                loaded weight differ, or a GGUF weight's output dimension is
                not divisible by ``tp_size``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Take this rank's contiguous slice along the output dimension.
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a weight by delegating to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the linear method, optionally all-gathering the result.

        Returns:
            The output tensor alone when ``return_bias`` is false; otherwise
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        # When skip_bias_add is set, the bias is returned rather than applied.
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a summary of sizes, bias presence, tp_size and gather_output."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
655
        quant_config: Quantization configuration.
656
657
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
658
        return_bias: If true, return bias together with outputs in forward pass.
659
660
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
661
662
    """

663
664
665
666
667
668
669
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Record the per-shard output sizes and build the fused column layer.

        The base layer is constructed with ``output_size=sum(output_sizes)``;
        every entry of ``output_sizes`` must be divisible by the TP size.
        """
        # These attributes are assigned before super().__init__() runs
        # (presumably the base constructor relies on them — confirm).
        self.output_sizes = output_sizes
        if disable_tp:
            # With TP disabled the layer behaves as a single, rank-0 partition.
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        assert all(size % self.tp_size == 0 for size in output_sizes)

        fused_output_size = sum(output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=fused_output_size,
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
James Fleming's avatar
James Fleming committed
694

695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
    def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, tuple):
            for idx in loaded_shard_id:
                if not (0 <= idx < len(self.output_sizes)):
                    raise ValueError(
                        f"Shard id index {idx} should be between 0 and "
                        f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
                    )
            if len(loaded_shard_id) > 1 and any(
                b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
            ):
                raise ValueError(
                    "Shard id with multiple indices should be consecutive. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        elif isinstance(loaded_shard_id, int):
            if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
                raise ValueError(
                    f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
                    f"Got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

722
723
724
725
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
726
        loaded_shard_id: tuple[int, ...] | int | None = None,
727
    ):
728
        self.validate_shard_id(loaded_shard_id)
729
730
731
732
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
733
734
735
736
737
738
        if isinstance(loaded_shard_id, tuple) and (
            is_gguf_weight or is_gguf_weight_type
        ):
            raise NotImplementedError(
                "Shard id with multiple indices is not supported for GGUF."
            )
739
        if is_gguf_weight_type:
740
741
742
743
744
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
745
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
746
                }
747
748
            return

749
750
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
751
752
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
753

754
            if loaded_shard_id is not None:
755
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
756
757
758
759
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
760

761
762
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
763
764
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
765

766
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
767
768
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
769
            if output_dim is None:
770
                if needs_scalar_to_array:
771
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
772
773
                        param_data, loaded_weight, 0
                    )
774

775
776
777
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
778
779
780
781
782
783

            output_sizes = (
                self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
                if loaded_shard_id is not None
                else self.output_sizes
            )
784
            current_shard_offset = 0
785
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
786
787
788
789
790
            if use_bitsandbytes_4bit and isinstance(loaded_shard_id, tuple):
                raise NotImplementedError(
                    "Shard id with multiple indices is not supported "
                    "for BNB quantization yet."
                )
791
            shard_offsets: list[tuple[int, int, int]] = []
792
            for i, output_size in enumerate(output_sizes):
793
794
795
796
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
797
                # Special case for Quantization.
798
799
800
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
801
802
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
803
                    # Special case for Marlin.
804
                    shard_size, shard_offset = adjust_marlin_shard(
805
806
                        param, shard_size, shard_offset
                    )
807

808
                if use_bitsandbytes_4bit:
809
810
811
812
813
814
815
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
816
817
                        param, orig_offsets, str(shard_id)
                    )
818

819
                loaded_weight_shard = loaded_weight.narrow(
820
821
                    output_dim, shard_offset, shard_size
                )
822
823
824
825
826
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
827
828
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]
829
830
            shard_offset //= self.tp_size
            shard_size //= self.tp_size
831
832
833
834
835
836
837

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

838
            # Special case for quantization.
839
840
841
842
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
843
844
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
845
                # Special case for Marlin.
846
                shard_size, shard_offset = adjust_marlin_shard(
847
848
                    param, shard_size, shard_offset
                )
849

850
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
851
852
853
854
855
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

856
            if use_bitsandbytes_4bit:
857
                shard_size = loaded_weight.shape[output_dim]
858
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
859

860
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
861
            start_idx = self.tp_rank * shard_size
862
            if not is_sharded_weight:
863
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
864
865
866
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
867
868
                param_data, loaded_weight, loaded_shard_id
            )
869

870
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
871
872
873
874
875
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
876
877
                    "the same for all partitions."
                )
878

879
880
881
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

882
    def _load_fused_module_from_checkpoint(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        output_sizes: list[int] | None = None,
        shard_ids: tuple[int, ...] | None = None,
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding the fused shards, laid
                out back to back along ``param.output_dim``.
            output_sizes: Size of each fused shard inside ``loaded_weight``;
                defaults to ``self.output_sizes``.
            shard_ids: Shard id passed to ``weight_loader_v2`` for each entry
                of ``output_sizes``; defaults to ``0..len(output_sizes) - 1``.

        Raises:
            ValueError: If ``shard_ids`` and ``output_sizes`` differ in length.
        """
        output_sizes = output_sizes or self.output_sizes
        if shard_ids is None:
            shard_ids = tuple(range(len(output_sizes)))
        elif len(shard_ids) != len(output_sizes):
            raise ValueError(
                f"Got {len(shard_ids)} shard ids for {len(output_sizes)} "
                "output sizes; they must have the same length."
            )

        # (shard_id, offset inside ``loaded_weight``, size).
        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for shard_id, output_size in zip(shard_ids, output_sizes):
            shard_offsets.append((shard_id, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: tuple[int, ...] | int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` via the parameter's own loader.

        Args:
            param: Destination parameter exposing ``load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` for a single
                shard; a tuple of consecutive indices when the checkpoint
                tensor holds several fused shards; ``None`` when it holds all
                of them fused together.

        Raises:
            ValueError: From ``validate_shard_id`` for an invalid shard id.
        """
        self.validate_shard_id(loaded_shard_id)
        if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                # Exact type match on purpose: subclasses of these two types
                # fall through to the fused-split path below.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return

            # A tuple carries the absolute shard ids of the fused tensor; they
            # are forwarded so each sub-shard is loaded into its own slot
            # rather than into slots counted from 0.
            shard_ids = loaded_shard_id if loaded_shard_id else None
            output_sizes = (
                [self.output_sizes[idx] for idx in shard_ids] if shard_ids else None
            )
            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                output_sizes = [
                    adjust_block_scale_shard(weight_block_size, size, 0)[0]
                    for size in (output_sizes or self.output_sizes)
                ]
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(
                param, loaded_weight, output_sizes=output_sizes, shard_ids=shard_ids
            )
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Destination region of this shard inside this rank's partition.
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]
        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
973

974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
996
        quant_config: Quantization configuration.
997
998
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
999
        return_bias: If true, return bias together with outputs in forward pass.
1000
        disable_tp: If true, weights matrix won't be sharded through tp rank.
1001
1002
    """

1003
1004
1005
1006
1007
    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        """Derive the per-rank head counts and build the fused QKV layer.

        ``v_head_size`` falls back to ``head_size`` and ``total_num_kv_heads``
        to ``total_num_heads`` when not given. The base layer is constructed
        with ``gather_output=False``.
        """
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = head_size if v_head_size is None else v_head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (
            total_num_heads if total_num_kv_heads is None else total_num_kv_heads
        )

        # Divide the weight matrix along the last dimension.
        tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)

        kv_heads_are_replicated = tp_size >= self.total_num_kv_heads
        if kv_heads_are_replicated:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and every KV head is shared by ``num_kv_head_replicas`` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

        # Per-rank widths of the three projections.
        q_size = self.num_heads * self.head_size
        k_size = self.num_kv_heads * self.head_size
        v_size = self.num_kv_heads * self.v_head_size

        self.output_sizes = [
            q_size * tp_size,  # q_proj
            k_size * tp_size,  # k_proj
            v_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=self.hidden_size,
            output_size=(q_size + k_size + v_size) * tp_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
1059

1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
    def validate_shard_id(self, loaded_shard_id: str | None):
        if loaded_shard_id is None:
            return
        if isinstance(loaded_shard_id, str):
            if loaded_shard_id not in ["q", "k", "v"]:
                raise ValueError(
                    "Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
                    f"got shard id {loaded_shard_id}."
                )
            return
        raise ValueError("This line should not be reached")

1072
1073
1074
1075
1076
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
1077
1078
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
1079
1080
1081
1082
1083
1084
1085
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
1086
            "v": self.num_kv_heads * self.v_head_size,
1087
1088
1089
        }
        return shard_size_mapping.get(loaded_shard_id)

1090
1091
1092
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Split a QKV tensor that is stored fused on disk and load each piece.

        Fused checkpoints carry no shard id, so the 'q', 'k' and 'v' slices
        are cut out of ``loaded_weight`` along ``param.output_dim`` using the
        total (un-partitioned) head counts, and each slice is passed to
        ``weight_loader_v2`` with its shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size

        # (shard_id, shard_offset, shard_size)
        layout = (
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        )

        is_packed_param = isinstance(
            param, (PackedColumnParameter, PackedvLLMParameter)
        )
        for shard_id, shard_offset, shard_size in layout:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if is_packed_param and param.packed_dim == param.output_dim:
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            piece = loaded_weight.narrow(param.output_dim, shard_offset, shard_size)
            self.weight_loader_v2(param, piece, shard_id)

1134
1135
1136
1137
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into ``param`` via ``param.load_qkv_weight``.

        Args:
            param: Destination parameter exposing ``load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: 'q', 'k' or 'v' for a single projection, or
                ``None`` when the checkpoint tensor is not split per projection.
        """
        self.validate_shard_id(loaded_shard_id)

        if loaded_shard_id is None:
            # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            block_size = getattr(self, "weight_block_size", None)
            size, offset = adjust_block_scale_shard(block_size, size, offset)

        # Note: the KV replica count is what gets passed as ``num_heads``.
        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=offset,
            shard_size=size,
            tp_rank=self.tp_rank,
        )
1173

1174
1175
1176
1177
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
1178
        loaded_shard_id: str | None = None,
1179
    ):
1180
        self.validate_shard_id(loaded_shard_id)
1181
1182
1183
1184
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
1185
        if is_gguf_weight_type:
1186
            idx_map = {"q": 0, "k": 1, "v": 2}
1187
1188
1189
1190
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
1191
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
1192
1193
            return

1194
1195
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
1196
1197
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
1198

1199
            if loaded_shard_id is not None:
1200
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
1201
1202
1203
1204
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
1205

1206
1207
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
1208

1209
1210
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
1211

1212
        if loaded_shard_id is None:
1213
1214
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
1215
            if output_dim is None:
1216
                if needs_scalar_to_array:
1217
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
1218
1219
                        param_data, loaded_weight, 0
                    )
1220

1221
1222
1223
1224
1225
1226
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
1227
1228
1229
1230
1231
1232
1233
1234
                (
                    "k",
                    self.total_num_heads * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
                (
                    "v",
                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
1235
                    self.total_num_kv_heads * self.v_head_size,
1236
                ),
1237
            ]
1238
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
1239

1240
1241
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
1242
                # Special case for Quantized Weights.
1243
1244
1245
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
1246
1247
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
1248

1249
                    # Special case for Marlin.
1250
                    shard_size, shard_offset = adjust_marlin_shard(
1251
1252
                        param, shard_size, shard_offset
                    )
1253

1254
1255
1256
                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
1257
1258
1259
1260
1261
1262
1263
                        "k": (
                            self.total_num_heads * self.head_size,
                            self.total_num_kv_heads * self.head_size,
                        ),
                        "v": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size,
1264
                            self.total_num_kv_heads * self.v_head_size,
1265
1266
                        ),
                        "total": (
1267
1268
1269
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size
                            + self.total_num_kv_heads * self.v_head_size,
1270
1271
                            0,
                        ),
1272
1273
1274
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
1275
1276
                        param, orig_qkv_offsets, shard_id
                    )
1277

1278
                loaded_weight_shard = loaded_weight.narrow(
1279
1280
                    output_dim, shard_offset, shard_size
                )
1281
1282
1283
1284
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]
1285
1286

        # If output dim is defined, use the default loading process.
1287
1288
1289
1290
1291
1292
1293
1294
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
1295
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
1296
                shard_size = self.num_kv_heads * self.v_head_size
1297
1298
1299
1300
1301
1302
1303

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

1304
            # Special case for Quantized Weights.
1305
1306
1307
1308
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
1309
1310
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
1311

1312
                # Special case for Marlin.
1313
                shard_size, shard_offset = adjust_marlin_shard(
1314
1315
                    param, shard_size, shard_offset
                )
1316

1317
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
1318
1319
1320
1321
1322
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

1323
            if use_bitsandbytes_4bit:
1324
1325
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
1326
1327
1328
1329
1330
1331
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
1332
                        self.num_kv_heads * self.v_head_size,
1333
1334
                    ),
                    "total": (
1335
1336
                        (self.num_heads + self.num_kv_heads) * self.head_size
                        + self.num_kv_heads * self.v_head_size,
1337
1338
                        0,
                    ),
1339
                }
1340
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
1341
1342
                    param, orig_qkv_offsets, loaded_shard_id
                )
1343

1344
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
1345
            if loaded_shard_id == "q":
1346
                shard_rank = self.tp_rank
1347
            else:
1348
1349
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size
1350

1351
            if not is_sharded_weight:
1352
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
1353

1354
1355
1356
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
1357
1358
                param_data, loaded_weight, loaded_shard_id
            )
1359
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
1360
1361
1362
1363
1364
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
1365
1366
                    "for all partitions."
                )
1367

1368
1369
1370
1371
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1372
# --8<-- [start:row_parallel_linear]
1373
@PluggableLayer.register("row_parallel_linear")
1374
class RowParallelLinear(LinearBase):
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y = X_iA_i.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """

1407
1408
    # --8<-- [end:row_parallel_linear]

1409
1410
1411
1412
1413
1414
1415
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Set up the per-rank input partition, the weights and the bias.

        Args:
            input_size: Full (unsharded) input dimension; divided by the
                tensor-parallel size to get the per-rank partition.
            output_size: Output dimension; kept whole on every rank.
            bias: If true, allocate a bias of length ``output_size``.
            input_is_parallel: Stored on the layer; read by ``forward`` to
                decide whether the input still has to be split.
            skip_bias_add: Forwarded to the base class initializer.
            params_dtype: Dtype forwarded to the base class and used for
                the bias allocation.
            reduce_results: Stored on the layer; read by ``forward`` to
                decide whether to all-reduce the output.
            quant_config: Forwarded to the base class initializer.
            prefix: Forwarded to the base class initializer.
            return_bias: Forwarded to the base class initializer.
            disable_tp: If true, the layer behaves as rank 0 of a
                tensor-parallel group of size 1.

        Raises:
            ValueError: If ``reduce_results`` is false while a bias is
                requested and ``skip_bias_add`` is false.
        """
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded, so there is a single
        # output partition spanning the whole output size.
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get
            # the v2 loader; everything else gets the legacy loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
1477
1478
1479

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``, keeping only this rank's shard.

        Optional attributes read from ``param`` (all via ``getattr``):
        ``input_dim``, ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
        ``is_gguf_weight`` and ``is_gguf_weight_type``.

        Args:
            param: Destination parameter. When it is a GGUF
                ``UninitializedParameter`` it is materialized here with the
                per-rank shape before the copy.
            loaded_weight: Tensor read from the checkpoint. It is narrowed
                along ``input_dim`` to this rank's slice unless the
                parameter is marked as already sharded (or uses
                bitsandbytes 4-bit).

        Raises:
            AssertionError: If the (possibly narrowed) loaded tensor does
                not have the same shape as the parameter data.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: dimension 0 is a valid
            # input dim and must still be divided across the ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # The parameter already has the per-rank size, so its extent
            # along input_dim is the shard size.
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

1513
    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
1514
1515
1516
1517
1518
1519
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

1520
1521
        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1522
    def forward(
1523
1524
        self,
        input_,
1525
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
1526
1527
1528
1529
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
1530
1531
                input_, num_partitions=self.tp_size
            )
1532
            input_parallel = splitted_input[self.tp_rank].contiguous()
1533
1534

        # Matrix multiply.
1535
        assert self.quant_method is not None
1536
1537
1538
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
1539
1540
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

1541
        if self.reduce_results and self.tp_size > 1:
1542
            output = tensor_model_parallel_all_reduce(output_parallel)
1543
        else:
1544
1545
            output = output_parallel

1546
1547
        if not self.return_bias:
            return output
1548
        output_bias = self.bias if self.skip_bias_add else None
1549
        return output, output_bias
1550
1551

    def extra_repr(self) -> str:
        """Return a comma-separated summary of the layer's configuration.

        The summary lists the per-partition input size, the output size,
        whether a bias is present, the tensor-parallel size and the
        ``reduce_results`` flag.
        """
        fields = (
            f"in_features={self.input_size_per_partition}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"reduce_results={self.reduce_results}",
        )
        return ", ".join(fields)