linear.py 65.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
17
18
19
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
20
from vllm.logger import init_logger
21
from vllm.model_executor.custom_op import CustomOp
22
from vllm.model_executor.layers.quantization.base_config import (
23
24
25
    QuantizationConfig,
    QuantizeMethodBase,
)
26
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
27
28
29
30
31
32
33
34
35
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
36
from vllm.model_executor.utils import set_weight_attrs
37
from vllm.platforms import current_platform
38
from vllm.utils import GiB_bytes
39
40
41

logger = init_logger(__name__)

42
# Quantization-method class names for which linear layers register
# ``weight_loader_v2`` (which delegates to the ``load_*`` methods of
# ``BasevLLMParameter``) instead of the legacy ``weight_loader``.
# Membership is checked against ``quant_method.__class__.__name__``, so every
# entry must match the class name exactly.
WEIGHT_LOADER_V2_SUPPORTED = [
    # Unquantized and compressed-tensors methods.
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # BitBLAS kernels.
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    # AWQ / GPTQ / Marlin / FP8 and other backends.
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    When ``param`` has a ``bitblas_tile_size`` attribute (not ``None``),
    both values are floor-divided by it. Otherwise they are returned
    unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into Marlin-tiled units.

    When ``param`` has a ``marlin_tile_size`` attribute (not ``None``),
    both values are multiplied by it. Otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Map a shard's unquantized (offset, size) onto a BitsAndBytes param.

    ``shard_offsets["total"]`` is a pair whose first element is the full
    unquantized size. ``shard_offsets[loaded_shard_id]`` is the shard's
    ``(offset, size)`` in that unquantized space. Both are rescaled
    proportionally to ``param.data.shape[0]`` using integer floor division.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that the size comes
        first.
    """
    total_size, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    def rescale(value: int) -> int:
        # Multiply before dividing so the floor division stays exact
        # whenever the ratio allows it.
        return value * packed_rows // total_size

    return rescale(size), rescale(offset)
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select one per-shard scale slot from a fused scale array.

    Fused modules (QKV and MLP) hold an array of length N with one scale per
    "logical" matrix, so ``param`` has length N. ``loaded_weight`` is the
    scale of a single on-disk shard. This picks the slot of ``param`` that
    matches ``shard_id`` and unwraps ``loaded_weight`` to a scalar.

    Args:
        param: Indexable fused scale array.
        loaded_weight: Scale read from disk. It is either 0-d (AutoFP8) or
            has a leading dimension of size 1 (compressed-tensors).
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight_scalar)``.

    Raises:
        KeyError: If ``shard_id`` is a string other than "q"/"k"/"v".
        ValueError: If ``shard_id`` is neither ``str`` nor ``int``.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales have no shape; compressed-tensors scales do.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """Split BitsAndBytes 4-bit shard metadata into first shard and remainder.

    For example, given::

        {
            'bnb_shard_offsets': array([0, 4, 8, 16]),
            'bnb_quant_state': {0: ..., 1: ..., 2: ...},
        }

    the function returns the pair::

        {'bnb_shard_offsets': array([0, 4]), 'bnb_quant_state': {0: ...}}
        {'bnb_shard_offsets': array([0, 4, 12]),
         'bnb_quant_state': {0: ..., 1: ...}}

    The right-hand offsets are rebased to start at 0, and its quant-state
    keys are renumbered from 0.

    Args:
        bnb_weight_attrs: Dict with ``"bnb_shard_offsets"`` and
            ``"bnb_quant_state"`` (keyed by shard index). The offsets must
            support slicing and subtracting a scalar, e.g. a NumPy array.
            This is an assumption; confirm against callers.

    Returns:
        A ``(left, right)`` tuple of dicts with the same two keys.
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    left_offsets = offsets[:2]
    right_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    left_states = {0: states[0]}
    # Shards 1..len(offsets)-2 move down by one index.
    right_states = {j: states[j + 1] for j in range(len(offsets) - 2)}

    left = dict(bnb_shard_offsets=left_offsets, bnb_quant_state=left_states)
    right = dict(bnb_shard_offsets=right_offsets, bnb_quant_state=right_states)
    return left, right
class LinearMethodBase(QuantizeMethodBase):
    """Interface shared by every (possibly quantized) linear method."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights of a linear layer and attach them to ``layer``.

        Args:
            layer: The layer that owns this linear method.
            input_size_per_partition: Input-dimension size of the weight
                held by the current rank.
            output_partition_sizes: Output-dimension size, on the current
                rank, of each logical weight fused into this layer. For a QKV
                layer this is a list of the widths of Wq, Wk and Wv.
            input_size: Input-dimension size of the weight across all ranks.
            output_size: Output-dimension size of the weight across all ranks.
            params_dtype: Data type of the created parameters.
            **extra_weight_attrs: Extra attributes passed by the layer, such
                as ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the linear computation of ``layer`` on ``x``.

        ``create_weights`` must already have been called on ``layer``.
        ``bias`` is the optional bias tensor supplied by the caller.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization: a single dense weight."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. That is
        the fused output widths on this rank by this rank's input width, with
        ``output_dim=0`` and ``input_dim=1``. ``extra_weight_attrs`` must
        contain ``"weight_loader"``; the remaining entries are set as
        attributes on the weight.

        Raises:
            RuntimeError: If the allocation raises a CUDA out-of-memory
                error. The original error is chained as the cause.
        """
        loader = extra_weight_attrs.pop("weight_loader")
        fused_rows = sum(output_partition_sizes)
        # Only the allocation is guarded; it is the step that can run OOM.
        try:
            weight = ModelWeightParameter(
                data=torch.empty(
                    fused_rows,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            )
        except torch.cuda.OutOfMemoryError as oom:
            logger.error("Failed to create unquantized linear weights: %s", oom)
            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / GiB_bytes
                reserved = torch.cuda.memory_reserved() / GiB_bytes
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB", allocated)
                logger.debug("Reserved: %.2f GiB", reserved)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight."
            ) from oom

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, pass the layer to the CPU GEMM dispatcher.

        ``dispatch_cpu_unquantized_gemm`` is called with
        ``remove_weight=True``. On other platforms the layer is left as is.
        """
        if not current_platform.is_cpu():
            return
        # Imported lazily because it is only needed on the CPU path.
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute ``x`` times ``layer.weight`` (plus ``bias``) via the dispatched GEMM."""
        gemm = dispatch_unquantized_gemm()
        return gemm(layer, x, layer.weight, bias)
class LinearBase(CustomOp):
    """Base linear layer.

    Stores the layer configuration, selects a quantization method, and
    resolves the tensor-parallel rank and size for the layer.

    Args:
        input_size: Input dimension of the linear layer.
        output_size: Output dimension of the linear layer.
        skip_bias_add: If true, skip adding the bias and return it instead.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization config. ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return the bias together with the outputs in
            the forward pass.
        disable_tp: If true, tensor parallelism is disabled for this layer
            (rank 0 of a world of size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Layer configuration.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            torch.get_default_dtype() if params_dtype is None else params_dtype
        )
        self.quant_config = quant_config
        self.prefix = prefix

        # get_quant_method receives the partly-initialized layer, so the
        # attributes above are already set when it runs.
        self.quant_method: Optional[QuantizeMethodBase] = (
            UnquantizedLinearMethod()
            if quant_config is None
            else quant_config.get_quant_method(self, prefix=prefix)
        )
        self.return_bias = return_bias
        self.disable_tp = disable_tp

        if disable_tp:
            self.tp_rank, self.tp_size = 0, 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank`` and ``tp_size`` onto its vLLM params.

        Only parameters that are ``BasevLLMParameter`` instances are updated.
        """
        vllm_params = (
            p for p in self.parameters() if isinstance(p, BasevLLMParameter)
        )
        for param in vllm_params:
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full, unsharded weight. ``create_weights`` receives
    ``input_size`` as the per-partition input size, and ``weight_loader``
    copies the loaded tensor without narrowing it.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Forwarded to ``LinearBase``; the weights of a replicated
                    layer are never sharded either way.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # (A subclass sets ``self.output_sizes`` before calling this.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy the full ``loaded_weight`` into ``param`` (no sharding).

        GGUF weight-type markers are recorded on ``param.weight_type``. GGUF
        ``UninitializedParameter``s are materialized to the loaded shape and
        dtype. 0-d tensors (e.g. AutoFP8 scales) are reshaped to ``(1,)``.

        Raises:
            AssertionError: If ``param`` and ``loaded_weight`` sizes differ.
        """
        # Special case for GGUF.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer to ``x``.

        Returns the output alone when ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``, where the bias is returned un-added only
        if ``skip_bias_add`` is set (else ``None``).
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true (and tp_size > 1), call all-gather on output
                       and make Y available to all GPUs, otherwise, every GPU
                       will have its output which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: Accepted for backward compatibility but not read by this
                      constructor. Packed layers (QKV / merged column) set
                      ``self.output_sizes`` before calling ``super().__init__``;
                      when that attribute exists, it determines the per-rank
                      ``output_partition_sizes``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Methods listed in WEIGHT_LOADER_V2_SUPPORTED use the
            # BasevLLMParameter-based loader.
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype is the resolved dtype (default applied).
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load an on-disk tensor into ``param``, keeping this rank's slice.

        If ``param`` has an ``output_dim`` and the tensor is not already
        sharded, ``loaded_weight`` is narrowed along ``output_dim``. "Already
        sharded" means ``is_sharded_weight`` is set or bitsandbytes 4-bit is
        in use. The narrowed slice has length ``param.shape[output_dim]`` and
        starts at ``tp_rank`` times that length. GGUF weight-type markers are
        recorded. GGUF ``UninitializedParameter``s are materialized with
        ``output_dim`` divided by ``tp_size``. 0-d tensors are reshaped to
        ``(1,)``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``.

        0-d tensors are reshaped to ``(1,)`` first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, optionally all-gathering the per-rank outputs."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
622
        quant_config: Quantization configure.
623
624
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
625
        return_bias: If true, return bias together with outputs in forward pass.
626
627
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
628
629
    """

630
631
632
633
634
635
636
637
638
639
640
641
    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Build a column-parallel layer that packs several output matrices.

        ``output_sizes`` lists the full (unsharded) output width of each
        packed matrix. The layer's total output size is their sum, and every
        width must be divisible by the tensor-parallel size.
        """
        # Must be set before super().__init__(): ColumnParallelLinear checks
        # ``hasattr(self, "output_sizes")`` to derive per-rank partition sizes.
        self.output_sizes = output_sizes
        if disable_tp:
            self.tp_size, self.tp_rank = 1, 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        # Every packed width must split evenly across the TP ranks.
        assert not any(size % self.tp_size for size in output_sizes)

        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
668
669
670
671
672
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
673
674
675
676
677
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
678
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
679
                }
680
681
            return

682
683
        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
684
685
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
686

687
            if loaded_shard_id is not None:
688
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
689
690
691
692
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return
693

694
695
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
696
697
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
698

699
        if loaded_shard_id is None:
700
701
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
702
            if output_dim is None:
703
                if needs_scalar_to_array:
704
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
705
706
                        param_data, loaded_weight, 0
                    )
707

708
709
710
711
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
712
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
713
            shard_offsets: list[tuple[int, int, int]] = []
714
715
716
717
718
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
719
                # Special case for Quantization.
720
721
722
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
723
724
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
725
                    # Special case for Marlin.
726
                    shard_size, shard_offset = adjust_marlin_shard(
727
728
                        param, shard_size, shard_offset
                    )
729

730
                shard_size, shard_offset = adjust_bitblas_shard(
731
732
                    param, shard_size, shard_offset
                )
733

734
                if use_bitsandbytes_4bit:
735
736
737
738
739
740
741
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
742
743
                        param, orig_offsets, str(shard_id)
                    )
744

745
                loaded_weight_shard = loaded_weight.narrow(
746
747
                    output_dim, shard_offset, shard_size
                )
748
749
750
751
752
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
753
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
754
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
755
            # Special case for quantization.
756
757
758
759
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
760
761
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
762
                # Special case for Marlin.
763
                shard_size, shard_offset = adjust_marlin_shard(
764
765
                    param, shard_size, shard_offset
                )
766
            shard_size, shard_offset = adjust_bitblas_shard(
767
768
                param, shard_size, shard_offset
            )
769

770
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
771
772
773
774
775
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

776
            if use_bitsandbytes_4bit:
777
                shard_size = loaded_weight.shape[output_dim]
778
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
779

780
            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
781
            start_idx = self.tp_rank * shard_size
782
            if not is_sharded_weight:
783
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
784
785
786
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
787
788
                param_data, loaded_weight, loaded_shard_id
            )
789

790
        else:
CHU Tianxiang's avatar
CHU Tianxiang committed
791
792
793
794
795
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
796
797
                    "the same for all partitions."
                )
798

799
800
801
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

802
803
804
    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Load a merged-column weight whose sub-projections are stored already
        concatenated in the checkpoint (e.g. MLP layers fused on disk), so no
        shard id is available.

        The tensor is cut into consecutive slices along ``param.output_dim``,
        one per entry of ``self.output_sizes``: slice ``i`` starts at
        ``sum(self.output_sizes[:i])`` and spans ``self.output_sizes[i]``.
        Each slice is forwarded to ``weight_loader_v2`` with ``i`` as its
        shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        cursor = 0
        for shard_idx, width in enumerate(self.output_sizes):
            # Offsets are accumulated from the unadjusted sizes; only the
            # per-slice copies below are rescaled for packing.
            start, length = cursor, width
            cursor += width

            # Quantized parameters packed along the output dimension need
            # their offset/size converted to account for the packing.
            packed_along_output = (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            )
            if packed_along_output:
                length, start = param.adjust_shard_indexes_for_packing(
                    shard_size=length, shard_offset=start
                )

            piece = loaded_weight.narrow(param.output_dim, start, length)
            self.weight_loader_v2(param, piece, shard_idx)
    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[int] = None,
    ):
        """
        Load a checkpoint tensor into ``param`` via the parameter's
        ``load_merged_column_weight`` method.

        Args:
            param: destination parameter.
            loaded_weight: tensor read from the checkpoint.
            loaded_shard_id: index into ``self.output_sizes`` of the
                sub-projection being loaded, or ``None`` when the checkpoint
                tensor is not split per sub-projection.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
            elif type(param) is RowvLLMParameter or type(param) is BasevLLMParameter:
                # Exact type comparison: subclasses of these are not matched
                # here and fall through to the fused-checkpoint path.
                param.load_merged_column_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        # Element offset/size of this sub-projection across all TP ranks.
        preceding = sum(self.output_sizes[:loaded_shard_id])
        own = self.output_sizes[loaded_shard_id]

        if not isinstance(param, BlockQuantScaleParameter):
            shard_offset = preceding // self.tp_size
            shard_size = own // self.tp_size
        else:
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            block_size = self.weight_block_size
            assert block_size is not None
            block_n, _block_k = block_size[0], block_size[1]

            def n_blocks(count: int) -> int:
                # Ceiling division: number of block_n-wide blocks covering count.
                return (count + block_n - 1) // block_n

            # Block scales hold one entry per block, so offsets/sizes are
            # expressed in blocks before splitting across ranks.
            shard_offset = n_blocks(preceding) // self.tp_size
            shard_size = n_blocks(own) // self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and each KV head is shared by `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        )
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Start of shard ``q``/``k``/``v`` (or ``"total"`` width) in this
        rank's local QKV output, in elements; ``None`` for unknown ids."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Width of shard ``q``/``k``/``v`` in this rank's local QKV output,
        in elements; ``None`` for unknown ids."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """``(shard_id, shard_offset, shard_size)`` for q, k and v inside a
        checkpoint tensor that holds all heads (not yet split across ranks)."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _get_bnb_qkv_offsets(
        self, num_heads: int, num_kv_heads: int
    ) -> dict[str, tuple[int, int]]:
        """Offset table passed to ``adjust_bitsandbytes_4bit_shard``.

        Maps ``q``/``k``/``v`` to ``(offset, size)`` and ``"total"`` to
        ``(total_width, 0)`` for a QKV layout with the given head counts.
        """
        q_size = num_heads * self.head_size
        kv_size = num_kv_heads * self.head_size
        return {
            "q": (0, q_size),
            "k": (q_size, kv_size),
            "v": (q_size + kv_size, kv_size),
            "total": (q_size + 2 * kv_size, 0),
        }

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load a checkpoint tensor via the parameter's ``load_qkv_weight``.

        ``loaded_shard_id`` is ``"q"``, ``"k"`` or ``"v"``, or ``None`` when
        the checkpoint stores the QKV projection already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block scales hold one entry per block: convert to block units.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Copy this rank's part of a checkpoint tensor into ``param``.

        ``loaded_shard_id`` is ``"q"``, ``"k"`` or ``"v"``, or ``None`` when
        the checkpoint stores QKV fused (e.g. Phi-3's ``qkv_proj``). In that
        case the tensor is split per shard and this method recurses once per
        shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # Offsets in the fused, unsharded layout (all heads).
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param,
                        self._get_bnb_qkv_offsets(
                            self.total_num_heads, self.total_num_kv_heads
                        ),
                        shard_id,
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank q/k/v layout that weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # Offsets in this rank's local layout.
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param,
                    self._get_bnb_qkv_offsets(self.num_heads, self.num_kv_heads),
                    loaded_shard_id,
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            # Query heads are partitioned across ranks; KV heads may be
            # replicated, so `num_kv_head_replicas` ranks read the same chunk.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1259
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        # Without the all-reduce every rank would add its own copy of the
        # bias to a partial sum.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        A parameter with an ``input_dim`` attribute is narrowed to the
        ``tp_rank``-th chunk along that dimension. The chunk width is taken
        from the local parameter. Narrowing is skipped when the tensor is
        already sharded (``is_sharded_weight``) or loaded by bitsandbytes
        4-bit. Scalar (0-d) tensors are reshaped to ``(1,)``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with the per-rank shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than relying on truthiness: 0 is a
            # valid input_dim and must be split too. Otherwise the narrow
            # below would take a full-width chunk at tp_rank * full_width.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via the parameter's ``load_row_parallel_weight`` method,
        reshaping scalar (0-d) tensors to ``(1,)`` first."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Compute this rank's partial product and optionally all-reduce it.

        Returns ``output`` or ``(output, output_bias)`` depending on
        ``return_bias``. ``output_bias`` is ``self.bias`` only when
        ``skip_bias_add`` is set, otherwise ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    The query projection is applied to the decoder hidden states, and the
    key/value projection to the encoder hidden states. Both sub-projections
    are kept in the plain dict ``self.proj``, so their parameters are not
    registered on this module. This module itself only owns empty placeholder
    parameters, which lets it be loaded as a drop-in replacement for a
    ``QKVParallelLinear`` module.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Resolve the documented default here. Previously None was counted as
        # zero KV heads in output_size and forwarded unchanged to the encoder
        # projection, which is built with total_num_heads=0 (so presumably it
        # would also fall back to 0 KV heads), leaving kv_size == 0.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
        )

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            placeholder_size,
            [placeholder_size],
            placeholder_size,
            placeholder_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder",
        )

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder",
        )

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader_v1,
                },
            )
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-projection."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _synced_proj(
        self, mode: Literal["q_proj_decoder", "kv_proj_encoder"]
    ) -> nn.Module:
        """Return ``self.proj[mode]`` after copying placeholder attrs onto it.

        For every placeholder parameter of this module that has a
        same-named parameter on the sub-layer, attributes present on the
        placeholder but missing on the sub-layer's parameter are copied over
        via ``sync_weight_attrs``.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    # NOTE(review): both properties below are evaluated on every `forward`
    # call, so the attribute sync above runs per call as well.
    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Query projection applied to decoder hidden states."""
        return self._synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Key/value projection applied to encoder hidden states."""
        return self._synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes set on ``src_param`` but absent on ``tgt_param``.

        For bitsandbytes 4-bit parameters, the missing attributes are first
        split into q and kv parts by ``left_shift_bitsandbytes_4bit_shard``,
        and only the part matching ``mode`` is applied to ``tgt_param``.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict
            )
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two parameters match when they have the same type and equal instance
        attributes, ignoring the weight-loader attributes.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ("weight_loader", "_weight_loader")
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Exactly one parameter of ``layer`` must match ``param`` according to
        ``_is_same_param``; otherwise an AssertionError is raised.
        """
        target_param_list = [
            v for _, v in layer.named_parameters() if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1, (
            f"expected exactly one matching param, found {len(target_param_list)}"
        )
        return target_param_list[0]

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to q and, if given, encoder states to k/v.

        Returns:
            ``(q, k, v)``; ``k`` and ``v`` are ``None`` when
            ``encoder_hidden_states`` is ``None``.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _resolve_shard_target(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ) -> tuple[nn.Module, nn.Parameter, tuple]:
        """Map a placeholder param and shard id to the sub-layer loader inputs.

        Returns ``(layer, target_param, shard_id_args)``: shard ``"q"`` goes to
        the decoder q projection, and any other id (including ``None``) goes to
        the encoder kv projection.
        """
        layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder
        target_param = self.select_proj_params(layer, param)
        # The q sub-layer's loader is called without a shard id; the kv
        # sub-layer's loader receives it.
        shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load ``loaded_weight`` through the sub-layer's v1 ``weight_loader``."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._resolve_shard_target(
            param, loaded_shard_id
        )
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        """Load ``loaded_weight`` into the matching sub-layer parameter.

        Uses the sub-layer's ``weight_loader_v2`` when the quant method's
        class name is listed in ``WEIGHT_LOADER_V2_SUPPORTED``, otherwise its
        v1 ``weight_loader``.
        """
        layer, target_param, shard_id_args = self._resolve_shard_target(
            param, loaded_shard_id
        )
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Describe the q/kv split and parallel settings for ``repr()``."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s