linear.py 67.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           ModelWeightParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import GiB_bytes
34
35
36

logger = init_logger(__name__)

# Class names of quantize methods whose layers load weights through
# ``weight_loader_v2`` (which takes a ``BasevLLMParameter``) instead of the
# legacy ``weight_loader``. Layers compare this list against
# ``quant_method.__class__.__name__``, so entries must match class names
# exactly.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]
60

61

62
63
64
65
66
67
68
69
70
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    When ``param`` has a non-None ``bitblas_tile_size`` attribute, both
    values are floor-divided by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the optional rescaling.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


71
72
73
74
75
76
77
78
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size.

    When ``param`` has a non-None ``marlin_tile_size`` attribute, both
    values are multiplied by it; otherwise they are returned unchanged.

    Returns:
        ``(shard_size, shard_offset)`` after the optional rescaling.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


79
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` maps each shard id to ``(offset, size)`` in the
    unquantized output dimension. Its ``"total"`` entry carries the full
    unquantized length as its first element. The offset and size of
    ``loaded_shard_id`` are rescaled proportionally to the length of
    ``param.data`` along dim 0, using integer floor division.

    Args:
        param: Quantized parameter whose ``data.shape[0]`` is the target
            length.
        shard_offsets: Per-shard ``(offset, size)`` plus a ``"total"`` entry.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that the size comes
        first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    # Map the shard window into the quantized parameter's coordinate space
    # by the ratio of the quantized length to the unquantized length.
    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-shard scale array for one shard.

    Fused modules (QKV and MLP) keep one scale per logical matrix in a
    length-N ``param``. ``loaded_weight`` is the scale of a single on-disk
    shard. ``shard_id`` is either an integer index or one of ``"q"``,
    ``"k"``, ``"v"``, which map to 0, 1 and 2.

    Returns:
        ``(param[index], loaded_weight)``. A ``loaded_weight`` with a
        leading dimension is unwrapped to its first element.

    Raises:
        KeyError: ``shard_id`` is a string other than ``"q"``/``"k"``/``"v"``.
        ValueError: ``shard_id`` is neither a ``str`` nor an ``int``.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape; compressed-tensors scales carry a
    # leading dimension that must be of size 1.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Separate the BitsAndBytes 4-bit shard.

    The first shard is split off from the rest. The remaining offsets are
    re-based so they start at 0, and the remaining quant states are
    re-indexed from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    pivot = offsets[1]
    quant_states = bnb_weight_attrs["bnb_quant_state"]

    remaining_states = {}
    for src in range(1, len(offsets) - 1):
        remaining_states[src - 1] = quant_states[src]

    left = {
        "bnb_shard_offsets": offsets[:2],
        "bnb_quant_state": {0: quant_states[0]},
    }
    right = {
        "bnb_shard_offsets": offsets[1:] - pivot,
        "bnb_quant_state": remaining_states,
    }
    return left, right


154
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional keyword attributes forwarded by
                the layer (the layers in this file pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: The layer holding the weights.
            x: Input tensor.
            bias: Optional bias to add to the output.

        Returns:
            The output tensor.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the unquantized ``weight`` and register it on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``. It holds
        only this rank's partition of the full weight; ``input_size`` and
        ``output_size`` are not used by this method.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Input-dim size on this rank.
            output_partition_sizes: Output-dim sizes of each logical weight
                on this rank, concatenated along dim 0.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: dtype of the allocated weight.
            **extra_weight_attrs: Must contain ``weight_loader``. Any other
                entries are attached to the weight via ``set_weight_attrs``.

        Raises:
            KeyError: ``weight_loader`` is missing from
                ``extra_weight_attrs``.
            RuntimeError: the allocation raised
                ``torch.cuda.OutOfMemoryError``.
        """
        # Popped so it goes to the parameter constructor and is no longer
        # among the attributes applied by set_weight_attrs below.
        weight_loader = extra_weight_attrs.pop("weight_loader")

        # Keep the try body to the allocation, the only step expected to
        # raise OutOfMemoryError.
        try:
            weight = ModelWeightParameter(
                data=torch.empty(sum(output_partition_sizes),
                                 input_size_per_partition,
                                 dtype=params_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-load hook.

        On CPU platforms it calls
        ``dispatch_cpu_unquantized_gemm(layer, remove_weight=True)``. It is
        a no-op elsewhere.
        """
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x`` against ``layer.weight`` (plus optional ``bias``).

        The work is delegated to the GEMM callable returned by
        ``dispatch_unquantized_gemm()``.
        """
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
237
238


239
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None,
            ``torch.get_default_dtype()`` is used.
        quant_config: Quantization configuration. When None the layer uses
            ``UnquantizedLinearMethod``; otherwise it uses the method
            returned by ``quant_config.get_quant_method(self, prefix=prefix)``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is fixed to 0 and ``tp_size`` to 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        # With TP disabled, the layer behaves as rank 0 of a 1-rank group.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto each of its
        ``BasevLLMParameter`` parameters."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
294
295


296
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full weight is created on every rank and loaded without slicing.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        # (Such a subclass sets ``self.output_sizes`` before calling this.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Nothing is partitioned, so the per-partition input size equals the
        # full input size.
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` in full (no TP slicing).

        GGUF weight-type markers are recorded on ``param.weight_type``, and
        uninitialized GGUF parameters are materialized with the loaded
        shape before copying.

        Raises:
            AssertionError: the (possibly reshaped) loaded tensor's size
                differs from the parameter's.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the linear layer.

        Returns:
            ``output`` when ``return_bias`` is False. Otherwise
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

403

404
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This constructor does not read
                       it; packed subclasses set ``self.output_sizes`` before
                       calling ``__init__`` instead.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The partition sizes below need the TP rank/size before
        # LinearBase.__init__ runs (it assigns the same values again).
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` into ``param``, keeping this rank's slice.

        When ``param`` has an ``output_dim``, the slice
        ``[tp_rank * shard, (tp_rank + 1) * shard)`` along that dim is kept,
        where ``shard`` is the parameter's size on that dim. The slice is
        skipped if the weight is already sharded (``is_sharded_weight`` or
        bitsandbytes 4-bit). Uninitialized GGUF parameters are first
        materialized with the per-rank shape.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load via ``param.load_column_parallel_weight``.

        A 0-d scalar is reshaped to a 1-element tensor first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run this rank's column slice of the layer.

        The per-rank output is all-gathered only when ``gather_output`` is
        set and ``tp_size > 1``.

        Returns:
            ``output`` when ``return_bias`` is False. Otherwise
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            if ``skip_bias_add`` is set and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrices won't be sharded; this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        # Every sub-matrix must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_unsharded_shard_offsets(self) -> list[tuple[int, int, int]]:
        """Return ``(shard_id, shard_offset, shard_size)`` per sub-matrix.

        Offsets are cumulative positions along the output dimension of the
        full (not TP-sharded) fused weight, in ``self.output_sizes`` order.
        """
        offsets = itertools.accumulate(self.output_sizes, initial=0)
        return [(i, offset, size) for i, (offset, size) in enumerate(
            zip(offsets, self.output_sizes))]

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one sub-matrix, or the whole fused matrix, into ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint. When
                ``output_dim`` is set it is narrowed to this rank's slice,
                unless the parameter is flagged as already sharded
                (``is_sharded_weight`` or bitsandbytes 4-bit).
            loaded_shard_id: Index into ``self.output_sizes``. ``None`` means
                the checkpoint stores the already-fused matrix, which is then
                split and loaded one sub-matrix at a time.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight and loaded_shard_id is not None:
            # GGUF sub-matrices are collected in a container rather than
            # copied into a pre-allocated tensor.
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_unsharded_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Offsets of each sub-matrix in the unsharded checkpoint tensor,
            # used by bitsandbytes; loop-invariant, so built once.
            orig_offsets = {
                str(shard_id): (shard_offset, shard_size)
                for shard_id, shard_offset, shard_size in shard_offsets
            }
            orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for each sub-matrix.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            # Position of this rank's slice of the sub-matrix inside the
            # (already TP-sharded) fused parameter.
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_unsharded_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a sub-matrix through the ``BasevLLMParameter`` v2 API.

        Args:
            param: Destination parameter; it performs the copy itself via
                ``load_merged_column_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``. ``None`` means
                the checkpoint stores the already-fused matrix.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are expressed in ceil-divided units of block_n
            # output rows (presumably one scale per block -- verify against
            # BlockQuantScaleParameter).
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix won't be sharded across TP
                    ranks.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds a single KV
            # head, and each KV head lives on num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Offset of q/k/v (or the "total" end) in this rank's fused slice."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Size of q/k/v in this rank's fused slice (None if unknown id)."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_unsharded_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, shard_offset, shard_size)`` for q, k and v.

        Offsets are along the output dimension of the full (not TP-sharded)
        fused QKV weight as stored in a checkpoint.
        """
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_unsharded_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a q/k/v sub-matrix through the ``BasevLLMParameter`` v2 API.

        Args:
            param: Destination parameter; it performs the copy itself via
                ``load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of "q", "k", "v". ``None`` means the
                checkpoint stores the already-fused QKV matrix.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        # NOTE: the `num_heads` keyword receives the KV replica count.
        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load the q, k or v sub-matrix (or the fused QKV) into ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint. When
                ``output_dim`` is set it is narrowed to this rank's slice
                (K/V slices are shared by ``num_kv_head_replicas`` ranks),
                unless the parameter is flagged as already sharded
                (``is_sharded_weight`` or bitsandbytes 4-bit).
            loaded_shard_id: One of "q", "k", "v". ``None`` means the
                checkpoint stores the already-fused QKV matrix, which is then
                split and loaded one sub-matrix at a time.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight and loaded_shard_id is not None:
            output_dim = getattr(param, "output_dim", None)
            # K/V heads may be replicated across ranks when
            # tp_size > total_num_kv_heads; ranks sharing a replica must read
            # the same slice, exactly as in the non-GGUF path below. With no
            # replication this reduces to tp_rank * (rows // tp_size).
            if loaded_shard_id == "q":
                num_shards, shard_rank = self.tp_size, self.tp_rank
            else:
                num_shards = self.tp_size // self.num_kv_head_replicas
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            shard_size = loaded_weight.size(output_dim) // num_shards
            start_idx = shard_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_unsharded_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Offsets of q/k/v in the unsharded checkpoint tensor, used by
            # bitsandbytes; loop-invariant, so built once.
            orig_qkv_offsets = {
                shard_id: (shard_offset, shard_size)
                for shard_id, shard_offset, shard_size in shard_offsets
            }
            orig_qkv_offsets["total"] = (
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for each sub-matrix.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Position of this rank's q/k/v slice inside the fused parameter.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    qkv_id: (self._get_shard_offset_mapping(qkv_id),
                             self._get_shard_size_mapping(qkv_id))
                    for qkv_id in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Ranks sharing a replicated KV head read the same K/V slice.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Create the row-sharded weight (and optional full bias).

        Raises:
            ValueError: if ``reduce_results`` is False while the bias would
                be added inside ``forward`` (``bias and not skip_bias_add``).
        """
        # Divide the weight matrix along the first dimension. These must be
        # known before super().__init__ because the partition sizes below
        # depend on them.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # The output dimension is not sharded: every rank produces the full
        # output width (a partial sum before the all-reduce).
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        # Without the all-reduce each rank holds only a partial sum, so adding
        # the bias here would not yield a correct result.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` carries an ``input_dim`` attribute and the checkpoint
        tensor is not already sharded, the slice
        ``[tp_rank * shard, (tp_rank + 1) * shard)`` along ``input_dim`` is
        taken, where ``shard`` is the size of ``param`` along that dim.
        GGUF parameters are materialized first, and 0-dim checkpoint values
        are reshaped to shape ``(1,)``.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness): input_dim == 0 is a valid axis
            # and must also be divided, otherwise the narrow below would use
            # the unsharded size as the shard size.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``.

        0-dim (single-element) checkpoint values are reshaped to ``(1,)``
        first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the row-parallel matmul.

        Returns ``output`` when ``return_bias`` is False, otherwise
        ``(output, output_bias)``. ``output_bias`` is ``self.bias`` only
        when ``skip_bias_add`` is set, else None.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Take this rank's chunk of the last dimension.
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-rank shape and parallel settings."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1376
1377


1378
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Queries are projected from the decoder hidden states by a
    ``ColumnParallelLinear``. Keys/values are projected from the encoder
    hidden states by a ``QKVParallelLinear`` that has zero query heads.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Honour the documented default (one KV head per query head). The
        # encoder projection below is built with total_num_heads=0, so
        # forwarding None to it would presumably leave it with no KV heads
        # instead of total_num_heads.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = dict()
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Empty placeholder; bias loading is routed to the sub-layers.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on both sub-layers."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _proj_with_synced_attrs(
        self,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ) -> nn.Module:
        """Return ``self.proj[mode]`` after copying placeholder attributes.

        For every parameter of this module that has a same-named parameter
        on the sub-layer, attributes the sub-layer's parameter lacks are
        copied over via ``sync_weight_attrs``.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Query projection applied to the decoder hidden states."""
        return self._proj_with_synced_attrs("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Key/value projection applied to the encoder hidden states."""
        return self._proj_with_synced_attrs("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but not ``tgt_param``.

        For bitsandbytes 4-bit parameters the missing attributes are split
        by ``left_shift_bitsandbytes_4bit_shard`` and only the half matching
        ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if (missing_attrs_dict and use_bitsandbytes_4bit):
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two parameters match when they have the same type and equal
        ``__dict__`` contents, ignoring the weight-loader entries.
        """
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Asserts that exactly one parameter of ``layer`` matches.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project to ``(q, k, v)``.

        ``k`` and ``v`` are None when ``encoder_hidden_states`` is None.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Route a placeholder (bias) load to the matching sub-layer.

        Shard ``"q"`` goes to ``q_proj_decoder`` without a shard id; any other
        shard id is forwarded to ``kv_proj_encoder``.
        """
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a placeholder weight load to the matching sub-layer.

        Uses the sub-layer's ``weight_loader_v2`` when this module's quant
        method is listed in ``WEIGHT_LOADER_V2_SUPPORTED``, otherwise its
        ``weight_loader``.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize the projection sizes and parallel settings."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s