linear.py 60.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6

7
from typing import Any
8
from vllm import envs
9
10

import torch
11
from torch.nn.parameter import Parameter, UninitializedParameter
12

13
14
15
16
17
18
19
20
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
21
from vllm.logger import init_logger
22
from vllm.model_executor.custom_op import CustomOp
23
from vllm.model_executor.layers.quantization.base_config import (
24
25
26
    QuantizationConfig,
    QuantizeMethodBase,
)
27
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
28
29
30
31
32
33
34
35
36
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
37
from vllm.model_executor.utils import set_weight_attrs
38
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
39

zhuwenwen's avatar
zhuwenwen committed
40
import os
41
from vllm.model_executor.utils import gemm_bank_conf
42
43
44

# Module-level logger for this file (used e.g. for FP8 block-shape debug messages).
logger = init_logger(__name__)

45
# Class names of quantize methods whose parameters are loaded through
# ``weight_loader_v2`` (which delegates to the parameter's own
# ``load_*`` helpers) instead of ``weight_loader``. Membership is checked
# against ``quant_method.__class__.__name__`` when weights are created.
WEIGHT_LOADER_V2_SUPPORTED = [
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
71

72

73
74
75
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset down to BitBLAS tile units.

    Args:
        param: Parameter that may carry a ``bitblas_tile_size`` attribute.
        shard_size: Shard size before tile adjustment.
        shard_offset: Shard offset before tile adjustment.

    Returns:
        ``(shard_size, shard_offset)`` floor-divided by ``bitblas_tile_size``
        when the parameter has that attribute, otherwise returned unchanged.
    """
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    if bitblas_tile_size is not None:
        return (shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size)
    return shard_size, shard_offset


81
82
83
84
85
86
87
88
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset up by the Marlin tile size.

    Parameters without a ``marlin_tile_size`` attribute are passed through
    untouched.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by the tile size when
        one is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


89
90
91
92
93
94
95
96
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a shard's size and offset into block-scale units.

    The first entry of ``weight_block_size`` is used as the block length;
    both values are divided by it, rounding up.

    Returns:
        ``(scale_shard_size, scale_shard_offset)``.
    """
    assert weight_block_size is not None
    block_len = weight_block_size[0]

    def _ceil_blocks(count):
        # Ceiling division written in the same form as the callers expect.
        return (count + block_len - 1) // block_len

    return _ceil_blocks(shard_size), _ceil_blocks(shard_offset)


97
98
99
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized shard layout is rescaled proportionally to the packed
    length ``param.data.shape[0]``.

    Args:
        param: The BitsAndBytes parameter being loaded.
        shard_offsets: Maps each shard id to ``(offset, size)``; the first
            element of the ``"total"`` entry is the total unquantized length.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before the floor division so integer precision is preserved.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


112
113
114
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: Fused per-shard scale array, indexed by shard position.
        loaded_weight: Scale for one shard as read from disk. It is either
            0-d (AutoFP8) or has a leading dimension of size 1
            (compressed-tensors).
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an
            integer shard index.

    Returns:
        ``(param[shard_id], loaded_weight)``. When ``envs.VLLM_USE_NN`` is
        set, ``loaded_weight`` is passed through ``.t()``, which leaves 0-d
        and 1-d tensors unchanged.

    Raises:
        ValueError: If ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: If ``shard_id`` is a string other than ``"q"``/``"k"``/``"v"``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    param_slice = param[shard_id]
    if envs.VLLM_USE_NN:
        # NOTE(review): presumably keeps the scale consistent with the
        # transposed NN weight layout; a no-op for 0-d/1-d scales.
        loaded_weight = loaded_weight.t()
    return param_slice, loaded_weight
136
137


138
139
140
141
142
143
144
145
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(
    bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Separate the BitsAndBytes 4-bit shard.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    ``bnb_shard_offsets`` must support slicing and broadcast subtraction of
    a scalar (e.g. a numpy array, as in the example above).
    """
    shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
    quant_states = bnb_weight_attrs["bnb_quant_state"]

    # Left part: only the first shard.
    offset_l = shard_offsets[:2]
    quant_state_l = {0: quant_states[0]}

    # Right part: the remaining shards, re-based so they start at offset 0
    # and re-indexed so their quant-state keys start at 0.
    offset_r = shard_offsets[1:] - shard_offsets[1]
    quant_state_r = {
        i - 1: quant_states[i] for i in range(1, len(shard_offsets) - 1)
    }

    left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
    right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
    return left, right


174
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra keyword attributes supplied by the
                layer (the layers in this file pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two switches change the weight layout / GEMM path:

    * ``envs.VLLM_USE_NN``: the weight is allocated transposed, as
      ``(input, output)``, and :meth:`apply` computes ``x @ W`` directly.
    * ``LLAMA_NN=1`` (process environment, read once in ``__init__``):
      forces the same ``x @ W`` path in :meth:`apply`.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # NOTE(review): read here but not consulted anywhere in this class;
        # kept as a public attribute in case other code inspects it.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate an unquantized weight and register it on ``layer``.

        Allocates ``sum(output_partition_sizes) * input_size_per_partition``
        elements. ``extra_weight_attrs`` must contain ``weight_loader``; the
        remaining entries are attached to the weight via ``set_weight_attrs``.
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")
        total_output = sum(output_partition_sizes)

        # Under VLLM_USE_NN the weight is stored transposed as (in, out) so
        # apply() can compute x @ W without a transpose. input_dim/output_dim
        # are left describing the untransposed (out, in) checkpoint layout;
        # the weight loaders narrow the checkpoint tensor along output_dim
        # before transposing it.
        if envs.VLLM_USE_NN:
            shape = (input_size_per_partition, total_output)
        else:
            shape = (total_output, input_size_per_partition)

        weight = ModelWeightParameter(
            data=torch.empty(*shape, dtype=params_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the layer to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear projection of ``x`` with ``layer.weight``.

        With ``LLAMA_NN=1`` or ``envs.VLLM_USE_NN`` the weight is treated as
        ``(in, out)`` and multiplied directly. Otherwise the default
        unquantized GEMM dispatcher is used.
        """
        # NOTE(review): if LLAMA_NN=1 while envs.VLLM_USE_NN is off,
        # create_weights allocates (out, in) but this path multiplies x @ W;
        # confirm the two switches are always enabled together.
        if self.use_llama_nn or envs.VLLM_USE_NN:
            return self._matmul_nn(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

    @staticmethod
    def _matmul_nn(
        x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Return ``x @ weight (+ bias)`` for a weight stored as (in, out)."""
        if bias is None:
            return torch.matmul(x, weight)
        if len(x.shape) == 2:
            # addmm fuses the bias add for the plain 2-D case.
            return torch.addmm(bias, x, weight)
        return torch.matmul(x, weight) + bias
300

301

302
class LinearBase(CustomOp):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configure. When ``None`` the layer uses
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer
            (``tp_rank`` is 0 and ``tp_size`` is 1).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        # Subclasses may set this to True (e.g. when merged output partitions
        # are not multiples of an FP8 weight block size).
        self.allow_fp8_block_shape_mismatch = False
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM parameters."""
        for param in self.parameters():
            if isinstance(param, BasevLLMParameter):
                param.tp_rank = self.tp_rank
                param.tp_size = self.tp_size
355


356
# --8<-- [start:replicated_linear]
357
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Stored on the layer as ``self.eps``; not read elsewhere in
            this class.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Has no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        eps: float | None = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        self.eps = eps

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Replicated: the per-partition input size is the full input size.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        # Anything other than the plain unquantized method counts as
        # quantized; weight_loader uses this to decide on the NN transpose.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full checkpoint tensor into ``param`` (no sharding)."""
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Unquantized weights are allocated transposed under VLLM_USE_NN, so
        # transpose the checkpoint tensor to match (a no-op for 1-D bias).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer to ``x``.

        Returns the output alone when ``return_bias`` is false; otherwise
        ``(output, bias)`` where ``bias`` is only returned (not added) when
        ``skip_bias_add`` is set, and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

482

483
# --8<-- [start:column_parallel_linear]
484
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. tp_size is
        # needed here, before LinearBase.__init__ (which recomputes it).
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
        # Anything other than the plain unquantized method counts as
        # quantized; the loaders use this to decide on the NN transpose.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag merged layers whose partitions don't align to the weight block.

        Sets ``allow_fp8_block_shape_mismatch`` when the quant config has a
        usable positive ``weight_block_size[0]``, there is more than one
        output partition, and at least one partition size is not a multiple
        of that block size.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        The checkpoint tensor is narrowed along ``param.output_dim`` unless
        the parameter is already sharded (or bitsandbytes 4-bit).
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # output_dim indexes the checkpoint tensor. Under VLLM_USE_NN an
            # unquantized 2-D param is allocated transposed, so its per-rank
            # output size sits on the other axis of param_data.
            param_is_transposed = (
                envs.VLLM_USE_NN
                and len(param_data.shape) != 1
                and not self.is_quantization
            )
            shard_axis = int(not output_dim) if param_is_transposed else output_dim
            shard_size = param_data.shape[shard_axis]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Match the transposed (in, out) layout; a no-op for 1-D tensors.
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load via the parameter's own column-parallel loading logic."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(
            loaded_weight=loaded_weight, is_quantization=self.is_quantization
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the column-parallel projection to ``input_``.

        All-gathers the per-rank outputs when ``gather_output`` is set and
        ``tp_size > 1``. Returns the output alone when ``return_bias`` is
        false, otherwise ``(output, bias-or-None)``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    When ``envs.VLLM_USE_NN`` is set and the layer is unquantized, 2-D
    weights are stored with their axes swapped relative to the checkpoint:
    ``weight_loader`` slices the parameter along the non-output axis and
    transposes the checkpoint tensor right before copying it in.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        eps: Stored on the layer as ``self.eps``; not read by this class
             itself (NOTE(review): presumably consumed by a fused kernel
             elsewhere — confirm).
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        eps: float | None = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.eps = eps
        # Set before super().__init__() so the base constructor can see them.
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every sub-matrix must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
        # Only unquantized layers use the transposed VLLM_USE_NN layout.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` (legacy loader path).

        Args:
            param: Destination parameter. Its optional attributes
                (``is_gguf_weight``/``is_gguf_weight_type``, ``output_dim``,
                ``packed_dim``/``packed_factor``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ``needs_scalar_to_array``) select the
                loading strategy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded, or ``None`` when the checkpoint
                already stores the merged weight; in that case it is split
                and each part is loaded through a recursive call.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            # (shard_id, offset, size) in the unpacked checkpoint layout.
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            shard_offset //= self.tp_size
            shard_size //= self.tp_size

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            if not envs.VLLM_USE_NN or self.is_quantization or param_data.dim() == 1:
                param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            else:
                # VLLM_USE_NN keeps unquantized 2-D weights with swapped
                # axes, so the shard lives on the other axis of param_data;
                # loaded_weight is transposed to match just before the copy.
                param_data = param_data.narrow(
                    int(not output_dim), shard_offset, shard_size
                )
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        # Checkpoint tensors use the standard layout; bring them into the
        # transposed VLLM_USE_NN layout (a no-op for 1-D tensors).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into a vLLM parameter (v2 loader path).

        Args:
            param: Destination parameter; its ``load_merged_column_weight``
                method performs the copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None``
                when the checkpoint stores the already-merged weight.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            # Exact type match on purpose: subclasses of these parameter
            # types are handled by the fused-module path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
            is_quantization=self.is_quantization,
        )

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    When ``envs.VLLM_USE_NN`` is set and the layer is unquantized, 2-D
    weights are stored with their axes swapped relative to the checkpoint:
    ``weight_loader`` slices the parameter along the non-output axis and
    transposes the checkpoint tensor right before copying it in.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head (used for q and k).
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, assume
                     v_head_size = head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer kv heads than ranks: each rank holds one kv head, and
            # each kv head is replicated on num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
        # Only unquantized layers use the transposed VLLM_USE_NN layout.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank offset of shard ``loaded_shard_id``.

        Accepts ``"q"``, ``"k"``, ``"v"`` or ``"total"`` (the full per-rank
        size); any other id yields ``None``.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank size of shard ``"q"``/``"k"``/``"v"``, else ``None``."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.v_head_size,
            ),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into a vLLM parameter (v2 loader path).

        Args:
            param: Destination parameter; its ``load_qkv_weight`` method
                performs the copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores the already-fused QKV weight.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            # Exact type match on purpose: subclasses of these parameter
            # types are handled by the fused-module path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
            is_quantization=self.is_quantization,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into ``param`` (legacy loader path).

        Args:
            param: Destination parameter. Its optional attributes
                (``is_gguf_weight``/``is_gguf_weight_type``, ``output_dim``,
                ``packed_dim``/``packed_factor``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ``needs_scalar_to_array``) select the
                loading strategy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint stores the already-fused QKV weight; in that case
                it is split and each part is loaded through a recursive call.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                (
                    "k",
                    self.total_num_heads * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
                (
                    "v",
                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                    self.total_num_kv_heads * self.v_head_size,
                ),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (
                            self.total_num_heads * self.head_size,
                            self.total_num_kv_heads * self.head_size,
                        ),
                        "v": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size,
                            self.total_num_kv_heads * self.v_head_size,
                        ),
                        "total": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size
                            + self.total_num_kv_heads * self.v_head_size,
                            0,
                        ),
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.v_head_size

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.v_head_size,
                    ),
                    "total": (
                        (self.num_heads + self.num_kv_heads) * self.head_size
                        + self.num_kv_heads * self.v_head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            if not envs.VLLM_USE_NN or self.is_quantization or param_data.dim() == 1:
                param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            else:
                # VLLM_USE_NN keeps unquantized 2-D weights with swapped
                # axes, so the shard lives on the other axis of param_data;
                # loaded_weight is transposed to match just before the copy.
                param_data = param_data.narrow(
                    int(not output_dim), shard_offset, shard_size
                )
            # Replicated kv heads: consecutive groups of num_kv_head_replicas
            # ranks read the same k/v slice from the checkpoint.
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        # Checkpoint tensors use the standard layout; bring them into the
        # transposed VLLM_USE_NN layout (a no-op for 1-D tensors).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


1375
# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    When ``envs.VLLM_USE_NN`` is enabled and the layer is unquantized, the
    weight parameter is stored transposed relative to the checkpoint tensor;
    ``weight_loader`` transposes each loaded shard before copying it in.

    Arguments:
        input_size: First dimension of matrix A.
        output_size: Second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
            split across the GPUs and we do not split again.
        skip_bias_add: This was added to enable performance optimization where
            bias can be fused with other element-wise operations. We skip
            adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
            to all GPUs; otherwise every GPU keeps its partial output
            Y_i = X_i A_i.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.down_proj).
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # forward() fuses the bias only on TP rank 0 and relies on the
        # all-reduce to propagate it; without the reduction the ranks'
        # outputs would disagree. Reject this combination before any weight
        # tensors are allocated.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )

        if bias:
            # The bias has the full output size on every rank (not sharded).
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)

        self.update_param_tp_status()
        # Selects between the plain checkpoint layout (quantized) and the
        # VLLM_USE_NN transposed layout (unquantized) in the weight loaders.
        self.is_quantization = not isinstance(
            self.quant_method, UnquantizedLinearMethod
        )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        If ``param`` has an ``input_dim`` attribute and the tensor is not
        already sharded, ``loaded_weight`` is narrowed along ``input_dim`` to
        the slice belonging to ``self.tp_rank``. With ``envs.VLLM_USE_NN`` on
        an unquantized layer, the parameter is stored transposed, so the
        shard size is read from the parameter's other axis and the slice is
        transposed before the copy.

        Args:
            param: Destination parameter. Optional attributes read here:
                ``input_dim``, ``use_bitsandbytes_4bit``,
                ``is_sharded_weight``, ``is_gguf_weight``,
                ``is_gguf_weight_type``.
            loaded_weight: Tensor read from the checkpoint (full, or already
                sharded when ``is_sharded_weight``/bitsandbytes applies).
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None: input_dim == 0 is a valid axis and must
            # still be divided across TP ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            if not envs.VLLM_USE_NN or self.is_quantization:
                shard_size = param_data.shape[input_dim]
            else:
                # Transposed (NN) layout: the parameter's input axis is the
                # other one of its two dimensions.
                shard_size = param_data.shape[int(not input_dim)]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"RowParallelLinear weight shape mismatch: param "
            f"{tuple(param_data.shape)} vs loaded {tuple(loaded_weight.shape)}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Load a checkpoint tensor through the parameter's own row-parallel
        loader, passing along whether this layer is quantized."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(
            loaded_weight=loaded_weight, is_quantization=self.is_quantization
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the row-parallel matmul.

        If ``input_is_parallel`` is false, the input is split along its last
        dimension and this rank keeps its own chunk. The partial product is
        all-reduced when ``reduce_results`` is set and ``tp_size > 1``.

        Returns:
            The output tensor, or ``(output, bias)`` when ``return_bias`` is
            set; the returned bias is ``self.bias`` only if ``skip_bias_add``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize partition sizes and TP settings for ``repr(module)``."""
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s