linear.py 67.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
8

import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           ModelWeightParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import GiB_bytes
34
35
36

# Module-level logger obtained from vllm.logger.init_logger.
logger = init_logger(__name__)

# Class names of quantize methods for which the parallel linear layers pass
# ``weight_loader_v2`` (instead of the legacy ``weight_loader``) to
# ``create_weights``. Layers test ``quant_method.__class__.__name__`` for
# membership in this list.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
]


def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for BitBLAS-tiled parameters.

    When ``param`` carries a ``bitblas_tile_size`` attribute, both values
    are floor-divided by it; otherwise they are passed through unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset for Marlin-tiled parameters.

    When ``param`` carries a ``marlin_tile_size`` attribute, both values are
    multiplied by it; otherwise they are passed through unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` maps each shard id to an ``(offset, size)`` pair; its
    ``"total"`` entry holds the overall extent in its first slot. The
    selected shard's offset and size are rescaled proportionally to
    ``param.data.shape[0]`` (the quantized extent) using floor division.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note the size comes first,
        unlike the ``(offset, size)`` order of ``shard_offsets``.
    """
    total, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    packed_extent = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Map a position in the original layout onto the quantized layout.
        return value * packed_extent // total

    scaled_offset = _rescale(offset)
    scaled_size = _rescale(size)
    return scaled_size, scaled_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-shard scale array for one shard.

    For fused modules (QKV and MLP) we have an array of length N that holds
    one scale for each "logical" matrix, so ``param`` is an array of length
    N. ``loaded_weight`` corresponds to one of the shards on disk. Here we
    index ``param`` by ``shard_id`` so the caller can load into that slot.

    Args:
        param: Fused array indexed by shard position.
        loaded_weight: The scale for a single shard. It is either 0-dim
            (e.g. AutoFP8) or has a leading dimension of size 1
            (e.g. compressed-tensors).
        shard_id: ``"q"``, ``"k"`` or ``"v"``, or an integer index.

    Returns:
        ``(param[index], loaded_weight)``. When ``loaded_weight`` had a
        shape, it is returned with its leading size-1 dimension removed.

    Raises:
        ValueError: If ``shard_id`` is an unknown name or is neither a
            ``str`` nor an ``int``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names the same way as any other unsupported id
        # instead of letting a bare KeyError escape from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape;
    # compressed-tensors scales do have a shape (leading dim of size 1).
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1, (
            f"Expected a leading dimension of size 1 for a per-shard scale, "
            f"got shape {tuple(loaded_weight.shape)}")
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    The left part keeps the first two offsets and quant state 0. The right
    part keeps the offsets from index 1 onward, shifted so they start at 0,
    and the remaining quant states re-keyed starting from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    # Quant state i (for i >= 1) becomes state i - 1 of the right part.
    tail_states = {}
    for new_idx in range(len(offsets) - 2):
        tail_states[new_idx] = states[new_idx + 1]

    left = {"bnb_shard_offsets": head_offsets, "bnb_quant_state": head_states}
    right = {"bnb_shard_offsets": tail_offsets, "bnb_quant_state": tail_states}
    return left, right


class LinearMethodBase(QuantizeMethodBase):
    """Abstract base for the (maybe quantized) ways to run a linear layer.

    A concrete method allocates the layer's parameters in ``create_weights``
    and computes the layer's output from them in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights for a linear layer.

        The created weights are set as attributes of ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this lists the widths
                of Wq, Wk and Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights (callers pass e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer on ``x`` using the weights stored on ``layer``.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an unquantized weight and register it on ``layer``.

        The weight is a ``ModelWeightParameter`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` with dtype
        ``params_dtype``, registered as ``layer.weight`` (``input_dim=1``,
        ``output_dim=0``). The sizes passed in are per-partition (e.g.
        ColumnParallelLinear passes ``output_size // tp_size``), so the
        allocation is this rank's partition, not necessarily the full
        ``(output_size, input_size)`` matrix.

        Args:
            extra_weight_attrs: Must contain ``weight_loader``, which is given
                to the parameter; every other entry is attached to the weight
                via ``set_weight_attrs``.

        Raises:
            KeyError: If ``weight_loader`` is not supplied.
            RuntimeError: If the allocation raises
                ``torch.cuda.OutOfMemoryError``.
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")

        # The OOM diagnostics below are about the allocation, so only the
        # allocation is guarded.
        try:
            weight_data = torch.empty(sum(output_partition_sizes),
                                      input_size_per_partition,
                                      dtype=params_dtype)
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e

        weight = ModelWeightParameter(data=weight_data,
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)
        # weight_loader was popped above; attach whatever attributes remain.
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading hook: on CPU platforms, hand the layer to
        ``dispatch_cpu_unquantized_gemm`` with ``remove_weight=True``.

        Other platforms are left unchanged.
        """
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the output with the GEMM selected by
        ``dispatch_unquantized_gemm()``, called as
        ``(layer, x, layer.weight, bias)``."""
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)


class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Stores the constructor arguments, picks the quantization method and
    resolves the tensor-parallel rank/size for the layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configuration. When ``None``, an
            ``UnquantizedLinearMethod`` is used.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (rank 0 of a group of size 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Record the constructor arguments on the layer.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        # Without a quantization config, fall back to dense weights.
        quant_method: Optional[QuantizeMethodBase] = (
            UnquantizedLinearMethod() if quant_config is None else
            quant_config.get_quant_method(self, prefix=prefix))
        self.quant_method = quant_method
        self.return_bias = return_bias

        self.disable_tp = disable_tp
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy ``tp_rank``/``tp_size`` onto every vLLM parameter of the layer.

        Only ``BasevLLMParameter`` instances are touched; plain parameters
        are left as they are.
        """
        vllm_params = (p for p in self.parameters()
                       if isinstance(p, BasevLLMParameter))
        for param in vllm_params:
            param.tp_rank = self.tp_rank
            param.tp_size = self.tp_size


@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The weight is created with its full (unsharded) sizes, and checkpoint
    tensors are copied in whole, so every rank holds the complete layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        # LinearBase assigns quant_method; every linear layer requires one.
        assert self.quant_method is not None
        # The full sizes serve as both the per-partition and the global sizes.
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` as a whole (no sharding).

        GGUF weight-type markers are stored on ``param.weight_type``, and an
        uninitialized GGUF parameter is materialized to the loaded shape
        before the copy.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer to ``x``.

        Returns ``output`` alone when ``return_bias`` is false; otherwise
        ``(output, output_bias)``, where ``output_bias`` is the bias if
        ``skip_bias_add`` is set and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class;
                       per-partition sizes come from ``self.output_sizes``
                       when a subclass sets it before calling this
                       constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The TP rank/size and the partition sizes are computed before
        # LinearBase.__init__ runs (which sets tp_rank/tp_size again),
        # presumably so they already exist when
        # quant_config.get_quant_method(self, ...) inspects the layer.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)

        # Divide the weight matrix along the last dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.gather_output = gather_output

        assert self.quant_method is not None
        # Methods listed in WEIGHT_LOADER_V2_SUPPORTED use the
        # BasevLLMParameter-based loader; all others use the legacy one.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim`` and the tensor is not already
        sharded (``is_sharded_weight``, or a bitsandbytes 4-bit weight), the
        tensor is narrowed along ``output_dim`` to the block owned by
        ``self.tp_rank``. GGUF weight-type markers are stored on
        ``param.weight_type``, and an uninitialized GGUF parameter is
        materialized with its output dimension divided by ``tp_size``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load via ``param.load_column_parallel_weight``, giving 0-dim
        tensors a shape of ``(1,)`` first."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer; all-gather the output when ``gather_output`` is
        set and ``tp_size > 1``.

        Returns ``output`` alone when ``return_bias`` is false; otherwise
        ``(output, output_bias)``, where ``output_bias`` is the bias if
        ``skip_bias_add`` is set and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.mlp.gate_up_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrices are not sharded and this layer
                    is treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        # With TP disabled the layer acts as a single-rank replica.
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        # Each sub-projection must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_fused_shard_offsets(self) -> list[tuple[int, int, int]]:
        """Return ``(shard_id, offset, size)`` for every sub-projection.

        Offsets are cumulative sums of ``self.output_sizes``, i.e. positions
        within the full, unsharded fused weight along the output dimension.
        """
        starts = itertools.accumulate(self.output_sizes, initial=0)
        return [(shard_id, start, size)
                for shard_id, (start, size) in enumerate(
                    zip(starts, self.output_sizes))]

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor into ``param`` (v1 loader path).

        Args:
            param: Destination parameter. Optional attributes such as
                ``output_dim``, ``packed_dim``, ``is_gguf_weight`` and
                ``use_bitsandbytes_4bit`` select special handling.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-projection being loaded, or ``None`` when the checkpoint
                stores the sub-projections already fused (e.g. Phi-3's
                ``gate_up_proj``); in that case the tensor is split and this
                method recurses once per sub-projection.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            fused_offsets = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Unquantized layout of the fused weight; loop-invariant, so
                # build it once rather than per shard.
                orig_offsets = {
                    str(shard_id): (start, size)
                    for shard_id, start, size in fused_offsets
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in fused_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for this slice.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            # Per-rank offset/size of this sub-projection inside param_data.
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # NOTE(review): offset = this shard's size * shard index, which
                # is only correct when all preceding sub-projections have the
                # same size as this one — confirm for unequal output_sizes.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = (
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset))

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor into ``param`` (v2 loader path).

        Args:
            param: Destination vLLM parameter; the copy is performed by its
                ``load_merged_column_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None``
                when the checkpoint tensor is already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact type match on purpose: subclasses of these parameter
            # types fall through to the fused split below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Scales are stored per block of block_n output rows, so convert
            # row counts to (ceil-divided) block counts before TP-splitting.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # No more KV heads than ranks: every rank holds one KV head and
            # groups of ranks share (replicate) the same head.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Per-rank offset of shard ``"q"``/``"k"``/``"v"`` (or the per-rank
        fused width for ``"total"``) along the output dim; ``None`` for any
        other id."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Per-rank size of shard ``"q"``/``"k"``/``"v"`` along the output
        dim; ``None`` for any other id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_unsharded_qkv_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v within the
        full, unsharded fused QKV weight along the output dimension."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_unsharded_qkv_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = (
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset))

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into ``param`` (v2 loader path).

        Args:
            param: Destination vLLM parameter; the copy is performed by its
                ``load_qkv_weight`` method.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint tensor is already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            # Exact type match on purpose: subclasses of these parameter
            # types fall through to the fused split below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Convert row counts into (ceil-divided) block counts.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        # NOTE(review): the ``num_heads`` keyword receives the KV-head
        # replica count; presumably the parameter uses it to map tp_rank to
        # a KV shard — confirm in parameter.py.
        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into ``param`` (v1 loader path).

        Args:
            param: Destination parameter. Optional attributes such as
                ``output_dim``, ``packed_dim``, ``is_gguf_weight`` and
                ``use_bitsandbytes_4bit`` select special handling.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores QKV already fused (e.g. Phi-3's
                ``qkv_proj``); in that case the tensor is split and this
                method recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            shard_offsets = self._get_unsharded_qkv_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Unquantized layout of the fused weight; loop-invariant, so
                # build it once rather than per shard.
                orig_qkv_offsets = {
                    shard_id: (shard_offset, shard_size)
                    for shard_id, shard_offset, shard_size in shard_offsets
                }
                orig_qkv_offsets["total"] = (sum(
                    shard_size for _, _, shard_size in shard_offsets), 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for this slice.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank offset/size of this shard inside param_data.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    key: (self._get_shard_offset_mapping(key),
                          self._get_shard_size_mapping(key))
                    for key in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are partitioned one slice per rank; KV heads may be
            # replicated, so num_kv_head_replicas consecutive ranks read the
            # same KV slice.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized; when it
              is fused into the GEMM it is only applied on TP rank 0 so it is
              not added more than once by the all-reduce.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y = X_iA_i. Combining reduce_results=False
                        with an added (not skipped) bias raises ValueError.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED load through
            # the parameter-class based loader; all others use the legacy one.
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``, keeping this rank's shard.

        When ``param`` carries an ``input_dim`` attribute and the checkpoint
        tensor is not already sharded, the slice belonging to ``self.tp_rank``
        is narrowed out along ``input_dim`` before copying. GGUF
        ``UninitializedParameter``s are materialized first, with their
        ``input_dim`` extent divided by the TP size.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: input_dim == 0 is a valid axis
            # and must still be divided by the TP size, otherwise the narrow
            # below would index past the end for ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's own row-parallel loader.

        0-dim tensors (e.g. per-tensor scales) are reshaped to shape (1,)
        before delegating to ``param.load_row_parallel_weight``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        Returns the (optionally all-reduced) output, or
        ``(output, output_bias)`` when ``return_bias`` is set; ``output_bias``
        is ``self.bias`` only when ``skip_bias_add`` is set, else ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-partition layer configuration for ``repr``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


1366
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    Queries are projected from the decoder hidden states by an internal
    ``ColumnParallelLinear`` (``q_proj_decoder``). Keys/values are projected
    from the encoder hidden states by an internal KV-only
    ``QKVParallelLinear`` (``kv_proj_encoder``). Both live in a plain dict so
    their parameters are not auto-registered on this module. This module only
    owns zero-sized placeholder parameters that route checkpoint loading to
    the matching inner layer.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Resolve the documented default here rather than forwarding None:
        # kv_proj_encoder below is built with total_num_heads=0, so any
        # default it derived from its own total_num_heads would be 0.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each inner layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    def _synced_proj(
        self,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ) -> nn.Module:
        """Return inner layer ``mode`` after copying missing placeholder attrs.

        For every parameter name on this module that also exists on the inner
        layer, attributes present on the placeholder but absent on the inner
        parameter are copied over via ``sync_weight_attrs``.
        """
        layer = self.proj[mode]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param, target_param, mode=mode)
        return layer

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Decoder-side query projection, with weight attrs synced."""
        return self._synced_proj("q_proj_decoder")

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Encoder-side key/value projection, with weight attrs synced."""
        return self._synced_proj("kv_proj_encoder")

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes set on ``src_param`` but missing on ``tgt_param``.

        For bitsandbytes 4-bit placeholders the missing attributes are first
        split via ``left_shift_bitsandbytes_4bit_shard`` and only the half
        matching ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """Return ``(q, k, v)``; ``k`` and ``v`` are None without encoder input.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _route_shard(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ) -> tuple[nn.Module, torch.nn.Parameter, tuple[Optional[str], ...]]:
        """Map a placeholder param and shard id to its inner-layer target.

        Shard ``"q"`` goes to ``q_proj_decoder``, whose loaders take no shard
        id. Every other id (including None) goes to ``kv_proj_encoder`` with
        the id forwarded as an extra positional argument.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Always load through the inner layer's legacy ``weight_loader``."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._route_shard(
            param, loaded_shard_id)
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load into the inner layer, using v2 when the quant method allows."""
        layer, target_param, shard_id_args = self._route_shard(
            param, loaded_shard_id)
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize the per-partition Q/KV sizes for ``repr``."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s