# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
17
18
19
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
20
from vllm.logger import init_logger
21
from vllm.model_executor.custom_op import CustomOp
22
from vllm.model_executor.layers.quantization.base_config import (
23
24
25
    QuantizationConfig,
    QuantizeMethodBase,
)
26
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
27

28
# yapf: disable
29
30
31
32
33
34
35
36
37
38
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)

39
# yapf: enable
40
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
from vllm.utils import GiB_bytes
43
44
45

logger = init_logger(__name__)

# Class names of quantization methods whose parameters are loaded through a
# layer's ``weight_loader_v2``. ColumnParallelLinear compares
# ``self.quant_method.__class__.__name__`` against this list when choosing
# which loader to attach; methods not listed use the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
69

70

71
72
73
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for BitBLAS tiled weights.

    If ``param`` has a ``bitblas_tile_size`` attribute (not ``None``), both
    values are floor-divided by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is None:
        return shard_size, shard_offset

    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size


79
80
81
82
83
84
85
86
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset up by the Marlin tile size.

    Parameters without a ``marlin_tile_size`` attribute (or with it set to
    ``None``) leave both values untouched.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


87
88
89
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Logical offsets and sizes are rescaled proportionally from the logical
    ``total`` to ``param.data.shape[0]`` using integer floor division.

    Args:
        param: The packed BitsAndBytes parameter; only ``param.data.shape[0]``
            is read.
        shard_offsets: Maps each shard id to its logical ``(offset, size)``.
            The special key ``"total"`` maps to ``(total_size, _)``.
        loaded_shard_id: Key of the shard to look up in ``shard_offsets``.

    Returns:
        The shard's ``(size, offset)`` in the packed parameter. The order is
        size first, matching the other ``adjust_*_shard`` helpers.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


102
103
104
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused scale array that a loaded scale fills.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix in a
    length-N array ``param``, while ``loaded_weight`` is the scale of one
    on-disk shard. This returns the element of ``param`` addressed by
    ``shard_id`` together with the scale value to copy into it.

    Args:
        param: The fused scale array.
        loaded_weight: Scale of one shard. It is either 0-d (AutoFP8) or has
            a leading dimension of size 1 (compressed-tensors).
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        A ``(param[index], loaded_weight)`` tuple.

    Raises:
        ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: If ``shard_id`` is a string other than ``"q"``/``"k"``/``"v"``.
    """
    if not isinstance(shard_id, str):
        if not isinstance(shard_id, int):
            raise ValueError(f"Unknown Shard Id {shard_id}")
        index = shard_id
    else:
        index = {"q": 0, "k": 1, "v": 2}[shard_id]

    # AutoFP8 scales are 0-d; compressed-tensors scales carry a leading
    # length-1 dimension, which is indexed away here.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


125
126
127
128
129
130
131
132
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    Splits off the first shard from the remaining ones. For example, given
    bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    The right-hand offsets are re-based to start at 0, and its quant-state
    keys are renumbered from 0. Keys of ``bnb_weight_attrs`` other than these
    two are not carried over. ``bnb_shard_offsets`` must support slicing and
    scalar subtraction (e.g. an array or tensor, not a plain list).
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    quant_states = bnb_weight_attrs["bnb_quant_state"]

    offset_l = shard_offsets[:2]
    # Shift the remaining boundaries so the right half starts at offset 0.
    offset_r = shard_offsets[1:] - shard_offsets[1]

    quant_state_l = {0: quant_states[0]}
    quant_state_r = {
        i - 1: quant_states[i] for i in range(1, len(shard_offsets) - 1)
    }

    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


161
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra keyword attributes supplied by the
                layer (the layers in this module pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the unquantized weight for this rank and register it.

        The weight is one ``(sum(output_partition_sizes),
        input_size_per_partition)`` tensor of ``params_dtype``. It is
        registered on ``layer`` as ``weight``, with ``input_dim=1`` and
        ``output_dim=0``. ``extra_weight_attrs`` must contain
        ``weight_loader``; the remaining entries are set as attributes on the
        parameter.

        Raises:
            KeyError: If ``weight_loader`` is missing from
                ``extra_weight_attrs``.
            RuntimeError: If the allocation raises
                ``torch.cuda.OutOfMemoryError``.
        """
        # Popped outside the ``try`` so the OOM handler only guards the
        # allocation; a missing key still surfaces as ``KeyError``.
        weight_loader = extra_weight_attrs.pop("weight_loader")
        try:
            weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug(
                    "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes
                )
                logger.debug(
                    "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes
                )
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight."
            ) from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, hand the layer to ``dispatch_cpu_unquantized_gemm``.

        The call passes ``remove_weight=True``. On other platforms this is a
        no-op.
        """
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the GEMM returned by ``dispatch_unquantized_gemm()`` on ``x``."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
263
264


265
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization config. When ``None`` the layer uses
            ``UnquantizedLinearMethod``; otherwise it uses the method returned
            by ``quant_config.get_quant_method``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is set to 0 and ``tp_size`` to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
316
317


318
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect on how replicated weights are allocated or
            loaded (the full ``input_size`` x ``output_size`` is always used).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Replicated: every rank holds the full input and output dimensions.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        GGUF weight-type markers are stored on ``param.weight_type``, and
        uninitialized GGUF parameters are materialized to the loaded shape.
        0-d checkpoint values are reshaped to ``(1,)`` before the size check.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer to ``x``.

        Returns ``output`` alone when ``return_bias`` is false; otherwise
        ``(output, output_bias)``. ``output_bias`` is ``self.bias`` only when
        ``skip_bias_add`` is set, else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

434

435
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this constructor:
                       per-partition sizes come from ``self.output_sizes`` when
                       a subclass sets it before calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_rank/tp_size
        # are needed before LinearBase.__init__ (which sets them again).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a column-parallel checkpoint tensor.

        Unless the tensor is already sharded (``is_sharded_weight`` or
        bitsandbytes 4-bit), it is narrowed along ``param``'s ``output_dim``
        to the ``tp_rank``-th slice of length ``param.shape[output_dim]``.
        GGUF parameters are materialized with ``output_dim`` divided by
        ``tp_size``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``.

        0-d (single-element) tensors are first reshaped to ``(1,)``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply this rank's column shard to ``input_``.

        When ``gather_output`` is set and ``tp_size > 1``, the per-rank outputs
        are all-gathered. Returns ``output`` alone when ``return_bias`` is
        false, otherwise ``(output, output_bias)``. ``output_bias`` is
        ``self.bias`` only when ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
626
        quant_config: Quantization configure.
627
628
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
629
        return_bias: If true, return bias together with outputs in forward pass.
630
631
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
632
633
    """

634
635
636
637
638
639
640
641
642
643
644
645
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a column-parallel layer whose output packs ``output_sizes``.

        ``self.output_sizes`` is set before the parent constructor runs,
        because the parent reads it to compute per-partition output sizes.
        """
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Each logical output is sharded separately, so every one of them
        # (not just their sum) must split evenly across ranks.
        assert all(size % self.tp_size == 0 for size in output_sizes), (
            f"Each of output_sizes={output_sizes} must be divisible by "
            f"tp_size={self.tp_size}"
        )
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
672
673
674
675
676
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
677
678
679
680
681
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
682
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
683
                }
684
685
            return

686
687
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
688
689
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
690

691
            if loaded_shard_id is not None:
692
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
693
694
695
696
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
697

698
699
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
700
701
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
702

703
        if loaded_shard_id is None:
704
705
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
706
            if output_dim is None:
707
                if needs_scalar_to_array:
708
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
709
710
                        param_data, loaded_weight, 0
                    )
711

712
713
714
715
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
716
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
717
            shard_offsets: list[tuple[int, int, int]] = []
718
719
720
721
722
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
723
                # Special case for Quantization.
724
725
726
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
727
728
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
729
                    # Special case for Marlin.
730
                    shard_size, shard_offset = adjust_marlin_shard(
731
732
                        param, shard_size, shard_offset
                    )
733

734
                shard_size, shard_offset = adjust_bitblas_shard(
735
736
                    param, shard_size, shard_offset
                )
737

738
                if use_bitsandbytes_4bit:
739
740
741
742
743
744
745
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
746
747
                        param, orig_offsets, str(shard_id)
                    )
748

749
                loaded_weight_shard = loaded_weight.narrow(
750
751
                    output_dim, shard_offset, shard_size
                )
752
753
754
755
756
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
757
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
758
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
759
            # Special case for quantization.
760
761
762
763
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
764
765
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
766
                # Special case for Marlin.
767
                shard_size, shard_offset = adjust_marlin_shard(
768
769
                    param, shard_size, shard_offset
                )
770
            shard_size, shard_offset = adjust_bitblas_shard(
771
772
                param, shard_size, shard_offset
            )
773

774
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
775
776
777
778
779
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

780
            if use_bitsandbytes_4bit:
781
                shard_size = loaded_weight.shape[output_dim]
782
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
783

784
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
785
            start_idx = self.tp_rank * shard_size
786
            if not is_sharded_weight:
787
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
788
789
790
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
791
792
                param_data, loaded_weight, loaded_shard_id
            )
793

794
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
795
796
797
798
799
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
800
801
                    "the same for all partitions."
                )
802

803
804
805
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

806
807
808
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id.

        ``loaded_weight`` is split along ``param.output_dim`` into one slice
        per entry of ``self.output_sizes`` (shard ids ``0, 1, ...`` in order),
        and each slice is passed to :meth:`weight_loader_v2` with its shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Exclusive prefix sums: the start offset of each sub-layer.
        shard_starts = itertools.accumulate(self.output_sizes, initial=0)
        for shard_id, (shard_offset, shard_size) in enumerate(
            zip(shard_starts, self.output_sizes)
        ):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

842
843
844
845
846
847
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        """Load a checkpoint tensor into ``param`` through the vLLM parameter API.

        Args:
            param: Destination parameter. The copy itself is delegated to
                ``param.load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the sub-layer
                that ``loaded_weight`` belongs to, or ``None`` when the
                checkpoint stores the merged layer already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            # Exact type check (not isinstance): subclasses of these fall
            # through to the per-shard split below.
            if type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Global sizes of the sub-layers before this one, and of this one.
        preceding_size = sum(self.output_sizes[:loaded_shard_id])
        own_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Block scales are indexed in (rounded-up) blocks of block_n
            # output rows, then split across tensor-parallel ranks.
            shard_offset = (preceding_size + block_n - 1) // block_n // self.tp_size
            shard_size = (own_size + block_n - 1) // block_n // self.tp_size
        else:
            shard_offset = preceding_size // self.tp_size
            shard_size = own_size // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
887

888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head,
            # and each KV head is shared by num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Derived from the per-rank head counts, so replicated KV heads are
        # counted once per rank.
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        )
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str) -> Optional[int]:
        """Start of shard ``loaded_shard_id`` ("q", "k", "v" or "total") in
        this rank's fused QKV output; None for an unknown id."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str) -> Optional[int]:
        """Size of shard ``loaded_shard_id`` ("q", "k" or "v") in this rank's
        fused QKV output; None for any other id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """``(shard_id, offset, size)`` of q/k/v in a fused, unsharded
        checkpoint tensor (sizes use the *total* head counts)."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id.

        ``loaded_weight`` is split along ``param.output_dim`` into its q, k
        and v slices, and each slice is passed to :meth:`weight_loader_v2`
        with its shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load a checkpoint tensor into ``param`` through the vLLM parameter API.

        Args:
            param: Destination parameter. The copy itself is delegated to
                ``param.load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or ``None`` when the checkpoint
                stores q, k and v already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            # Exact type check (not isinstance): subclasses of these fall
            # through to the per-shard split below.
            if type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            # Convert element offsets/sizes to (rounded-up) block counts.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Legacy loader: copy this rank's part of ``loaded_weight`` into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v", or ``None`` when the checkpoint
                stores q, k and v already fused. In the fused case the tensor
                is split and this method is called again once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are collected in data_container rather than
                # copied into param.data here.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)

            # Unquantized (offset, size) table used to rescale bnb shards;
            # it does not depend on the loop variable, so build it once.
            orig_qkv_offsets = None
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard_id: (shard_offset, shard_size)
                    for shard_id, shard_offset, shard_size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    (self.total_num_heads + 2 * self.total_num_kv_heads)
                    * self.head_size,
                    0,
                )

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    shard_name: (
                        self._get_shard_offset_mapping(shard_name),
                        self._get_shard_size_mapping(shard_name),
                    )
                    for shard_name in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"),
                    0,
                )
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are always split across ranks; KV heads may be
            # replicated, so num_kv_head_replicas consecutive ranks read the
            # same KV slice.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1263
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Reject the invalid configuration before any weight is allocated:
        # without the all-reduce, each rank would add the bias to its
        # partial sum.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # Full (unsharded) bias; it has no input_dim, so weight_loader
            # copies it unchanged on every rank.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice (along ``input_dim``) of ``loaded_weight``
        into ``param``; tensors without ``input_dim`` are copied whole."""
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None: 0 is a valid dimension index.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Compute ``input_ @ A (+ b)``, all-reducing partial results when
        ``reduce_results`` is set and ``tp_size > 1``.

        Returns the output alone when ``return_bias`` is False, otherwise
        ``(output, bias_or_None)``; the bias is returned only when
        ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1447
1448


1449
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Queries are computed from the decoder hidden states by a
    ``ColumnParallelLinear`` sub-layer. Keys and values are computed from the
    encoder hidden states by a ``QKVParallelLinear`` sub-layer built with zero
    query heads. The sub-layers live in a plain dict (``self.proj``), so their
    parameters are not auto-registered on this module. Instead, this module
    creates placeholder parameters (size 0) whose weight loaders forward each
    loaded tensor to the matching sub-layer parameter.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Resolve the documented default here. The KV sub-layer below is
        # built with ``total_num_heads=0``, so handing it ``None`` would
        # presumably resolve against that 0 (per QKVParallelLinear's own
        # None handling; verify there) instead of this layer's query heads.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            placeholder_size,
            [placeholder_size],
            placeholder_size,
            placeholder_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder",
        )

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder",
        )

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Placeholder bias; its loader (weight_loader_v1) routes the
            # checkpoint tensor into the matching sub-layer's bias.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader_v1,
                },
            )
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on both sub-layers."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _get_synced_proj(self, name: Literal["q_proj_decoder", "kv_proj_encoder"]):
        """Return sub-layer ``name`` after syncing placeholder attributes onto it.

        For every parameter of this module whose name also exists on the
        sub-layer, attributes set on the placeholder but missing on the
        sub-layer's parameter are copied over via ``sync_weight_attrs``.
        """
        layer = self.proj[name]
        for param_name, param in self.named_parameters():
            target_param = getattr(layer, param_name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=name)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Query projection applied to the decoder hidden states."""
        return self._get_synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Key/value projection applied to the encoder hidden states."""
        return self._get_synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but not on ``tgt_param``.

        For bitsandbytes 4-bit parameters, the missing attributes are first
        split by ``left_shift_bitsandbytes_4bit_shard`` into a query part and
        a key/value part. Only the part selected by ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict
            )
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two parameters match when they have the same type and identical
        instance attributes, ignoring their weight loaders.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Exactly one parameter of ``layer`` must match (asserted).
        """
        target_param_list = [
            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to Q and, if given, encoder states to K/V.

        Args:
            decoder_hidden_states: Input to the query projection.
            encoder_hidden_states: Input to the key/value projection, or
                ``None`` when the encoder KV is already cached.

        Returns:
            ``(q, k, v)``; ``k`` and ``v`` are ``None`` when
            ``encoder_hidden_states`` is ``None``.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            return q, None, None
        # Prefill phase, encoder KV cached here.
        kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
        # Split the last dim into K and V chunks of width kv_size.
        k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _route_to_proj(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ):
        """Resolve where a tensor loaded for placeholder ``param`` should go.

        Shard ``"q"`` targets the query sub-layer, whose loader takes no
        shard id. Any other id (including ``None``) targets the KV sub-layer
        and is passed through to its loader unchanged.

        Returns:
            ``(layer, target_param, shard_id_args)``.
        """
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load ``loaded_weight`` into a sub-layer using its v1 loader."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._route_to_proj(
            param, loaded_shard_id
        )
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load ``loaded_weight`` into the matching sub-layer parameter.

        Uses the sub-layer's ``weight_loader_v2`` when this layer's quant
        method is listed in ``WEIGHT_LOADER_V2_SUPPORTED``. Otherwise it
        falls back to the v1 ``weight_loader``.
        """
        layer, target_param, shard_id_args = self._route_to_proj(
            param, loaded_shard_id
        )
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize this layer's configuration for ``repr()``."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s