linear.py 61.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6

7
from typing import Any
8
from vllm import envs
9
10

import torch
11
from torch.nn.parameter import Parameter, UninitializedParameter
12

13
14
15
16
17
18
19
20
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
21
from vllm.logger import init_logger
22
from vllm.model_executor.custom_op import CustomOp
23
from vllm.model_executor.layers.quantization.base_config import (
24
25
26
    QuantizationConfig,
    QuantizeMethodBase,
)
27
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
28
29
30
31
32
33
34
35
36
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    ModelWeightParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
37
from vllm.model_executor.utils import set_weight_attrs
38
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
39

zhuwenwen's avatar
zhuwenwen committed
40
import os
41
from vllm.model_executor.utils import gemm_bank_conf
42
43
44

# Module-scoped logger, named after this module.
logger = init_logger(__name__)

45
# Class names of quant methods for which ColumnParallelLinear passes
# ``weight_loader_v2`` (rather than ``weight_loader``) to create_weights().
# Membership is tested by ``quant_method.__class__.__name__``.
WEIGHT_LOADER_V2_SUPPORTED = [
    # Unquantized and compressed-tensors methods.
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # BitBLAS kernels.
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    # AWQ / GPTQ / Marlin / FP8 / TPU methods.
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    # ModelOpt FP8 variants.
    "ModelOptFp8LinearMethod",
    "ModelOptFp8PcPtLinearMethod",
    "ModelOptFp8PbWoLinearMethod",
    # IPEX, Quark and FP4 methods.
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    # Block-wise INT8.
    "BlockInt8LinearMethod",
]
70

71

72
73
74
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Express a shard's (size, offset) in units of the BitBLAS tile size.

    If ``param`` carries a non-``None`` ``bitblas_tile_size`` attribute, both
    values are floor-divided by it; otherwise they are returned unchanged.
    Returns ``(shard_size, shard_offset)``.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


80
81
82
83
84
85
86
87
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's (size, offset) up by the Marlin tile size.

    If ``param`` carries a non-``None`` ``marlin_tile_size`` attribute, both
    values are multiplied by it; otherwise they are returned unchanged.
    Returns ``(shard_size, shard_offset)``.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


88
89
90
91
92
93
94
95
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
    """Convert a shard's (size, offset) into counts of output blocks.

    ``weight_block_size[0]`` is the block length along the output dim; both
    values are rounded *up* to whole blocks (ceil division).
    Returns ``(shard_size, shard_offset)`` in block units.
    """
    assert weight_block_size is not None
    block_n = weight_block_size[0]

    def _ceil_blocks(value):
        # Ceil division: number of block_n-sized blocks needed to cover value.
        return (value + block_n - 1) // block_n

    return _ceil_blocks(shard_size), _ceil_blocks(shard_offset)


96
97
98
def adjust_bitsandbytes_4bit_shard(
    param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets["total"][0]`` is the unquantized total length and
    ``shard_offsets[loaded_shard_id]`` is ``(offset, size)`` in that same
    unquantized space. Both are rescaled proportionally to the packed length
    ``param.data.shape[0]`` using integer (floor) arithmetic.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    total_len, _ = shard_offsets["total"]
    offset, size = shard_offsets[loaded_shard_id]
    packed_len = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Multiply before dividing to keep the floor division exact.
        return value * packed_len // total_len

    return _rescale(size), _rescale(offset)


111
112
113
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-matrix scale array for one shard.

    Fused modules (QKV and MLP) keep an array of length N holding one scale
    per "logical" matrix. ``loaded_weight`` is a single on-disk shard; this
    returns ``(param[shard_id], loaded_weight)`` so the caller can copy the
    scale into the right slot.

    Args:
        param: The fused length-N scale array.
        loaded_weight: The on-disk scale. It is either 0-d (AutoFP8) or has a
            leading dimension of size 1 (compressed-tensors), which is dropped.
        shard_id: ``"q"``/``"k"``/``"v"`` (mapped to 0/1/2) or an int index.

    Raises:
        ValueError: if ``shard_id`` is neither a ``str`` nor an ``int``.
    """
    if isinstance(shard_id, str):
        shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are shapeless; compressed-tensors scales carry a leading
    # singleton dimension that is stripped here.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    use_nn = envs.VLLM_USE_NN
    fused_slot = param[shard_id]
    # In NN mode the scale is transposed (a no-op for 0-d / 1-d tensors).
    scale = loaded_weight.t() if use_nn else loaded_weight
    return fused_slot, scale
135
136


137
138
139
140
141
142
143
144
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Separate the BitsAndBytes 4-bit shard.

    Splits the first shard off from the rest and re-bases the remaining
    offsets so they start at zero. For example, given:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    the function returns the pair:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    ``bnb_shard_offsets`` must support slicing and scalar subtraction
    (e.g. a numpy array or torch tensor).
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head = offsets[:2]
    tail = offsets[1:] - offsets[1]
    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    # Quant-state keys of the remaining shards are shifted down by one.
    tail_states = {idx - 1: states[idx] for idx in range(1, len(offsets) - 1)}
    return (
        dict(bnb_shard_offsets=head, bnb_quant_state=head_states),
        dict(bnb_shard_offsets=tail, bnb_quant_state=tail_states),
    )


173
class LinearMethodBase(QuantizeMethodBase):
    """Abstract interface for (possibly quantized) linear-layer methods.

    Concrete subclasses allocate a layer's parameters in ``create_weights``
    and run the forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the weights of a linear layer and attach them to it.

        The weights become attributes of ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Width of the weight's input dim on
                this rank.
            output_partition_sizes: Output-dim width of each logical weight
                on this rank; e.g. for QKVLinear this lists the widths of
                Wq, Wk and Wv.
            input_size: Input-dim size of the weight across all ranks.
            output_size: Output-dim size of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra keyword attributes supplied by the
                calling layer (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the layer's weights on ``x`` (adding ``bias`` if given).

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    The weight layout chosen in ``create_weights`` depends on
    ``envs.VLLM_USE_NN``:

    * off: ``weight`` has shape
      ``(sum(output_partition_sizes), input_size_per_partition)`` and
      ``apply`` dispatches through ``dispatch_unquantized_gemm``.
    * on: ``weight`` is stored pre-transposed as
      ``(input_size_per_partition, sum(output_partition_sizes))`` and
      ``apply`` computes ``x @ weight (+ bias)`` directly.

    ``apply`` also takes the pre-transposed path when the ``LLAMA_NN``
    environment variable is ``"1"``.
    NOTE(review): ``LLAMA_NN`` does not change the layout produced by
    ``create_weights``; presumably it is meant to be combined with
    ``VLLM_USE_NN`` — confirm.
    """

    def __init__(self):
        # Environment flags are sampled once, when the method object is built.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        # Public attribute kept for compatibility; no method of this class
        # reads it.
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        Allocates ``sum(output_partition_sizes) * input_size_per_partition``
        elements of ``params_dtype``. The axis order depends on
        ``envs.VLLM_USE_NN`` (see the class docstring).
        ``extra_weight_attrs`` must contain ``weight_loader``. It is popped
        and given to the parameter; the remaining entries are attached via
        ``set_weight_attrs``.
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")
        out_features = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            # Pre-transposed (in, out) layout so apply() can do x @ weight.
            shape = (input_size_per_partition, out_features)
        else:
            shape = (out_features, input_size_per_partition)

        # NOTE(review): input_dim/output_dim are 1/0 in both layouts.
        # ColumnParallelLinear.weight_loader compensates for the transposed
        # NN layout; other loaders may not — verify before reusing.
        weight = ModelWeightParameter(
            data=torch.empty(*shape, dtype=params_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU, hand the layer to the CPU GEMM dispatcher.

        Calls ``dispatch_cpu_unquantized_gemm(layer, remove_weight=True)``.
        This is a no-op on every other platform.
        """
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    @staticmethod
    def _matmul_pretransposed(
        x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Return ``x @ weight (+ bias)`` for a weight stored as (in, out)."""
        if bias is None:
            return torch.matmul(x, weight)
        if len(x.shape) == 2:
            # addmm performs the bias add and the GEMM in a single call.
            return torch.addmm(bias, x, weight)
        return torch.matmul(x, weight) + bias

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the layer output for ``x``.

        Uses a plain matmul against the pre-transposed weight when
        ``LLAMA_NN=1`` or ``envs.VLLM_USE_NN`` is set. Otherwise it uses the
        platform GEMM chosen by ``dispatch_unquantized_gemm``.
        """
        if self.use_llama_nn or envs.VLLM_USE_NN:
            return self._matmul_pretransposed(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
299

300

301
class LinearBase(CustomOp):
    """Common base for the linear layers in this module.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, tensor parallelism will be disabled for this layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Record constructor arguments. A missing params_dtype falls back to
        # torch's current default dtype.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (
            params_dtype if params_dtype is not None else torch.get_default_dtype()
        )
        self.quant_config = quant_config
        self.prefix = prefix
        self.allow_fp8_block_shape_mismatch = False

        # get_quant_method() receives the partially-initialised layer, so this
        # stays after the assignments above (assumption: it may inspect them).
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if disable_tp:
            # Behave as a single-rank layer.
            self.tp_rank, self.tp_size = 0, 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's tp_rank/tp_size onto each BasevLLMParameter it owns."""
        owned = (p for p in self.parameters() if isinstance(p, BasevLLMParameter))
        for vllm_param in owned:
            vllm_param.tp_rank = self.tp_rank
            vllm_param.tp_size = self.tp_size
353
354


355
# --8<-- [start:replicated_linear]
356
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Take no effect for replicated linear layers.
    """

    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Replicated: every rank holds the full weight, so the per-partition
        # input size equals the full input size.
        self.quant_method.create_weights(
            self,
            self.input_size,
            self.output_partition_sizes,
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        # True when a quantized method is in use; weight_loader then skips
        # the NN-mode transpose.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        Handles GGUF weight-type tensors and lazily materialised GGUF
        parameters, gives 0-d tensors a shape of ``(1,)``, and transposes
        unquantized tensors when ``envs.VLLM_USE_NN`` is set.

        Raises:
            AssertionError: if the (possibly transposed) tensor's size does
                not match ``param``.
        """
        # Special case for GGUF: record the weight type and materialise an
        # UninitializedParameter from the checkpoint tensor's shape/dtype.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # In NN mode unquantized weights are stored pre-transposed, so the
        # checkpoint tensor is transposed to match (no-op for 0-d/1-d).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}"
        )
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the layer.

        Returns the output alone when ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``, where ``bias`` is the un-added bias if
        ``skip_bias_add`` is set and ``None`` if not.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

478

479
# --8<-- [start:column_parallel_linear]
480
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # TP rank/size are needed here, before LinearBase.__init__, to size
        # the output partitions; LinearBase assigns the same values again.
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1

        # Divide the weight matrix along the last (output) dimension.
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self._maybe_allow_fp8_block_shape_mismatch()

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2
                if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                else self.weight_loader
            ),
        )
        if bias:
            # self.params_dtype (resolved by LinearBase) for consistency with
            # the weights; identical to the raw argument when it was not None.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition, dtype=self.params_dtype)
            )
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
        # True when a quantized method is in use; loaders then skip the
        # NN-mode transpose.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
        """Flag ``allow_fp8_block_shape_mismatch`` for ragged fused partitions.

        Only considered for layers with more than one output partition whose
        quant config exposes a positive integer ``weight_block_size[0]``
        (block_n). The flag is set when any partition size is not a multiple
        of block_n.
        """
        quant_config = getattr(self, "quant_config", None)
        weight_block = getattr(quant_config, "weight_block_size", None)
        if (
            weight_block is None
            or len(weight_block) < 1
            or len(self.output_partition_sizes) <= 1
        ):
            return

        try:
            block_n = int(weight_block[0])
        except (ValueError, TypeError):
            return

        if block_n <= 0:
            return

        if any(size % block_n != 0 for size in self.output_partition_sizes):
            self.allow_fp8_block_shape_mismatch = True
            logger.debug(
                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
                getattr(self, "prefix", "<unknown>"),
                block_n,
                self.output_partition_sizes,
            )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's column shard of a checkpoint tensor into ``param``.

        The tensor is narrowed along ``param.output_dim`` to this rank's slice
        ``[tp_rank * shard_size, (tp_rank + 1) * shard_size)``. This is skipped
        when the parameter has no ``output_dim``, is marked
        ``is_sharded_weight``, or uses bitsandbytes 4-bit.
        0-d tensors get a shape of ``(1,)``. Unquantized tensors are
        transposed when ``envs.VLLM_USE_NN`` is set.

        Raises:
            AssertionError: if the final tensor's shape does not match
                ``param``.
        """
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            if (
                not envs.VLLM_USE_NN
                or len(param_data.shape) == 1
                or self.is_quantization
            ):
                shard_size = param_data.shape[output_dim]
            else:
                # NN mode stores unquantized 2-D weights pre-transposed while
                # the output_dim metadata is unchanged, so this rank's output
                # width is on the other axis of param_data.
                shard_size = param_data.shape[int(not output_dim)]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"to a parameter of shape {param_data.shape}"
        )
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(
            loaded_weight=loaded_weight, is_quantization=self.is_quantization
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply this rank's column shard and optionally all-gather the result.

        Returns the output alone when ``return_bias`` is false. Otherwise it
        returns ``(output, bias)``, where ``bias`` is the un-added bias if
        ``skip_bias_add`` is set and ``None`` if not.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.mlp.gate_up_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.output_sizes = output_sizes
        self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0

        # Every sub-projection must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0 for output_size in output_sizes)
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
        # The loaders below copy unquantized weights transposed when
        # envs.VLLM_USE_NN is set, so they need to know which case applies.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def forward(
        self,
        input_,
        *,
        iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Apply the merged column-parallel projection.

        Args:
            input_: Activations passed to ``self.quant_method.apply``.
            iqis: Optional pair of tensors forwarded to the quant method as
                ``input_quant_args`` when ``envs.USE_FUSED_RMS_QUANT`` is
                enabled. NOTE(review): presumably pre-computed input
                quantization results from a fused RMSNorm+quant kernel --
                confirm with the caller.

        Returns:
            The output when ``return_bias`` is false, otherwise an
            ``(output, bias)`` pair where ``bias`` is ``self.bias`` only if
            ``skip_bias_add`` is set, else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        if envs.USE_FUSED_RMS_QUANT and iqis is not None:
            # Hand the caller-provided input-quantization arguments straight
            # to the quant method.
            output_parallel = self.quant_method.apply(
                self, input_, bias, input_quant_args=iqis
            )
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into ``param`` (legacy loading path).

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-projection being loaded, or ``None`` when the checkpoint
                already stores the fused tensor (e.g. Phi-3's gate_up_proj).
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item() for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset
                )

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id)
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id])
            shard_size = self.output_sizes[loaded_shard_id]

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            shard_offset //= self.tp_size
            shard_size //= self.tp_size

            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )
            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset
            )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

            if not envs.VLLM_USE_NN or self.is_quantization or param_data.dim() == 1:
                param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            else:
                # With VLLM_USE_NN, unquantized weights are held transposed
                # (see the .t() below), so the output axis is the other one.
                # NOTE(review): assumes a 2-D parameter here.
                param_data = param_data.narrow(
                    int(not output_dim), shard_offset, shard_size
                )
            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions."
                )

        if envs.VLLM_USE_NN and not self.is_quantization:
            # Match the transposed storage layout of unquantized weights.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: int | None = None,
    ):
        """Load a checkpoint tensor into a ``BasevLLMParameter``.

        Args:
            param: Destination parameter; it performs the actual copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None`` for
                a tensor that is already fused on disk.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        shard_offset //= self.tp_size
        shard_size //= self.tp_size

        param.load_merged_column_weight(
            loaded_weight=loaded_weight,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
            is_quantization=self.is_quantization,
        )


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
        v_head_size: size of each value head. If None, defaults to head_size.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At most one KV head per rank: each rank holds a single KV head
            # and groups of ranks replicate the same one.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
        # The loaders below copy unquantized weights transposed when
        # envs.VLLM_USE_NN is set, so they need to know which case applies.
        self.is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Per-rank offset of shard ``q``/``k``/``v`` (or ``total`` size)."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Per-rank size of shard ``q``/``k``/``v``; ``None`` otherwise."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q/k/v in the fused,
        unsharded checkpoint layout (computed from the *total* head counts)."""
        q_size = self.total_num_heads * self.head_size
        k_size = self.total_num_kv_heads * self.head_size
        v_size = self.total_num_kv_heads * self.v_head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, k_size),
            ("v", q_size + k_size, v_size),
        ]

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in self._get_fused_shard_offsets():
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into a ``BasevLLMParameter``.

        Args:
            param: Destination parameter; it performs the actual copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` for a
                tensor that is already fused on disk.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            weight_block_size = getattr(self, "weight_block_size", None)
            shard_size, shard_offset = adjust_block_scale_shard(
                weight_block_size, shard_size, shard_offset
            )

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
            is_quantization=self.is_quantization,
        )

    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        """Load a checkpoint tensor into ``param`` (legacy loading path).

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``, or ``None`` when the
                checkpoint already stores the fused tensor (e.g. Phi-3's
                qkv_proj).
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            fused_shards = self._get_fused_shard_offsets()
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in fused_shards:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    # Offsets/sizes of the original (unpacked) fused tensor.
                    orig_qkv_offsets = {
                        sid: (offset, size) for sid, offset, size in fused_shards
                    }
                    orig_qkv_offsets["total"] = (
                        sum(size for _, _, size in fused_shards),
                        0,
                    )

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank layout that weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            if isinstance(param, BlockQuantScaleParameter):
                weight_block_size = getattr(self, "weight_block_size", None)
                shard_size, shard_offset = adjust_block_scale_shard(
                    weight_block_size, shard_size, shard_offset
                )

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    sid: (
                        self._get_shard_offset_mapping(sid),
                        self._get_shard_size_mapping(sid),
                    )
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            if not envs.VLLM_USE_NN or len(param_data.shape) == 1 or self.is_quantization:
                param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            else:
                # With VLLM_USE_NN, unquantized weights are held transposed
                # (see the .t() below), so the output axis is the other one.
                # NOTE(review): assumes a 2-D parameter here.
                param_data = param_data.narrow(
                    int(not output_dim), shard_offset, shard_size
                )
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                # Ranks that replicate the same KV head read the same slice.
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        if envs.VLLM_USE_NN and not self.is_quantization:
            # Match the transposed storage layout of unquantized weights.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


# --8<-- [start:row_parallel_linear]
1394
@CustomOp.register("row_parallel_linear")
1395
class RowParallelLinear(LinearBase):
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                        to all GPUs; otherwise, every GPU keeps only its
                        partial output Y_i = X_iA_i.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the weight matrix is not sharded across
                    tensor-parallel ranks.
    """

1428
1429
    # --8<-- [end:row_parallel_linear]

1430
1431
1432
1433
1434
1435
1436
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Row parallelism divides the weight along its input (first matrix)
        # dimension; the output dimension is kept whole on every rank.
        # With ``disable_tp`` the layer acts as rank 0 of a world of size 1.
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED are handed the
        # parameter-class-aware loader; all others get the legacy loader.
        supports_loader_v2 = (
            self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
        )
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if supports_loader_v2 else self.weight_loader
            ),
        )

        # Without the all-reduce each rank keeps only its partial output, so
        # adding the bias inside forward would not give a correct Y.
        if bias and not skip_bias_add and not reduce_results:
            raise ValueError(
                "When not reduce the results, adding bias to the "
                "results can lead to incorrect results"
            )

        if not bias:
            self.register_parameter("bias", None)
        else:
            # The bias covers the full output and is not sharded.
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(
                self.bias,
                {
                    "output_dim": 0,
                    "weight_loader": self.weight_loader,
                },
            )
        self.update_param_tp_status()
        # Any quant method other than UnquantizedLinearMethod counts as
        # quantized; both weight loaders read this flag.
        self.is_quantization = not isinstance(
            self.quant_method, UnquantizedLinearMethod
        )
1499
1500
1501

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's shard of a checkpoint tensor into ``param``.

        Row parallelism shards along the parameter's ``input_dim``: rank
        ``tp_rank`` receives the slice starting at ``tp_rank * shard_size``.
        Parameters without an ``input_dim`` attribute (e.g. the bias, which
        only carries ``output_dim``) are copied whole.

        When ``envs.VLLM_USE_NN`` is set and the layer is unquantized,
        multi-dimensional parameters are held transposed relative to the
        checkpoint. In that case the shard size is read from the other axis,
        and the loaded tensor is transposed before the copy.

        Args:
            param: Destination parameter owned by this layer.
            loaded_weight: Tensor read from the checkpoint. It is full-size
                unless ``param`` is marked pre-sharded or uses bitsandbytes
                4-bit loading, in which case it is not narrowed again.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads only the portion for this rank, and pre-sharded
        # weights are already split, so neither needs narrowing.
        is_sharded_weight = (
            getattr(param, "is_sharded_weight", False) or use_bitsandbytes_4bit
        )
        # Unquantized weights use the transposed local layout under NN mode.
        use_nn_layout = envs.VLLM_USE_NN and not self.is_quantization

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None`: 0 is a valid sharding axis, so a truthiness
            # test would leave the materialized shape unsharded.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # A 1-D tensor has no transposed layout, so only multi-dimensional
            # params read the swapped axis (same guard as the QKV loader).
            if use_nn_layout and param_data.dim() > 1:
                shard_size = param_data.shape[int(not input_dim)]
            else:
                shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if use_nn_layout:
            # No-op for 1-D tensors; swaps the axes of 2-D weights.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

1540
    def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
        """Hand row-parallel loading off to the vLLM parameter class.

        Zero-dimensional tensors are promoted to shape ``(1,)`` first. Scales
        read from disk (e.g. AutoFP8) often come without a shape.
        """
        if loaded_weight.dim() == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(
            loaded_weight=loaded_weight,
            is_quantization=self.is_quantization,
        )
1548

1549
    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Compute ``Y = XA + b`` using this rank's row shard of ``A``.

        When ``input_is_parallel`` is false, ``input_`` is split along its
        last dimension into ``tp_size`` chunks and only this rank's chunk is
        used. Partial outputs are all-reduced when ``reduce_results`` is set
        and ``tp_size > 1``.

        Returns:
            The output tensor, or ``(output, bias)`` when ``return_bias`` is
            set. The returned ``bias`` is the layer bias only when
            ``skip_bias_add`` is set, otherwise ``None``.
        """
        if not self.input_is_parallel:
            chunks = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
            input_parallel = chunks[self.tp_rank].contiguous()
        else:
            input_parallel = input_

        # Matrix multiply.
        assert self.quant_method is not None
        # The bias is not sharded, so only rank 0 fuses it into the GEMM;
        # otherwise the all-reduce would add it more than once.
        fuse_bias = self.tp_rank == 0 and not self.skip_bias_add
        output_parallel = self.quant_method.apply(
            self, input_parallel, self.bias if fuse_bias else None
        )

        needs_all_reduce = self.reduce_results and self.tp_size > 1
        output = (
            tensor_model_parallel_all_reduce(output_parallel)
            if needs_all_reduce
            else output_parallel
        )

        if self.return_bias:
            return output, (self.bias if self.skip_bias_add else None)
        return output
1577
1578

    def extra_repr(self) -> str:
        """Summarize the per-rank shape and parallel settings for ``repr``."""
        return (
            f"in_features={self.input_size_per_partition}"
            f", output_features={self.output_size}"
            f", bias={self.bias is not None}"
            f", tp_size={self.tp_size}"
            f", reduce_results={self.reduce_results}"
        )