linear.py 71.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
6
from typing import Any, Literal, Optional, Union
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.logger import init_logger
18
from vllm.model_executor.custom_op import CustomOp
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           ModelWeightParameter,
26
                                           PackedColumnParameter,
27
                                           PackedvLLMParameter,
28
29
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
30
# yapf: enable
31
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.utils import GiB_bytes
gaoqiong's avatar
gaoqiong committed
34

zhuwenwen's avatar
zhuwenwen committed
35
import os
36
from vllm.model_executor.utils import gemm_bank_conf
37

38

39
40
# Module logger; used e.g. to report allocation failures when creating
# unquantized linear weights.
logger = init_logger(__name__)

41
# Class names of quant methods whose layers get ``weight_loader_v2``
# instead of the legacy ``weight_loader``: ColumnParallelLinear compares
# ``quant_method.__class__.__name__`` against this list. Only membership is
# tested in the code visible in this file.
WEIGHT_LOADER_V2_SUPPORTED = [
    # No quantization / compressed-tensors.
    "UnquantizedLinearMethod",
    "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
    # BitBLAS kernels.
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    # Remaining weight-only / FP8 / FP4 / int8 methods.
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "PetitNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
65

66

67
68
69
70
71
72
73
74
75
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Express a shard's size and offset in BitBLAS tile units.

    When ``param`` has a non-None ``bitblas_tile_size`` attribute, both
    values are floor-divided by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


76
77
78
79
80
81
82
83
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size, if any.

    When ``param`` has a non-None ``marlin_tile_size`` attribute, both
    values are multiplied by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


84
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized layout is described by ``shard_offsets``: the first
    element of ``shard_offsets["total"]`` is the total unquantized size and
    ``shard_offsets[loaded_shard_id]`` is ``(offset, size)`` of one shard.
    Both numbers are rescaled proportionally onto the first dimension of
    ``param.data`` (the quantized storage), using integer floor division.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    full_size = shard_offsets["total"][0]
    start, length = shard_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    # Multiply before dividing so the rescaling stays exact for integer
    # ratios and floors otherwise.
    return (length * packed_rows // full_size,
            start * packed_rows // full_size)


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for one shard.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix in an
    array of length N. ``loaded_weight`` is the on-disk value for a single
    shard; ``param`` is sliced with the shard's index so the caller can load
    into it.

    Args:
        param: The fused array, indexed with the resolved integer shard id.
        loaded_weight: Scale for one shard. AutoFP8 scales are 0-d;
            compressed-tensors scales carry a leading dimension of size 1,
            which is stripped here.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an int.

    Returns:
        ``(param[index], loaded_weight)``; when ``envs.VLLM_USE_NN`` is set,
        the returned weight is ``loaded_weight.t()``.

    Raises:
        KeyError: for a string shard id other than "q", "k" or "v".
        ValueError: when ``shard_id`` is neither a str nor an int.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    if len(loaded_weight.shape) > 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    target = param[index]
    if envs.VLLM_USE_NN:
        # Transposed ("NN") layout; a no-op for 0-d and 1-d tensors.
        loaded_weight = loaded_weight.t()
    return target, loaded_weight
123
124


125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }

    The remainder's offsets are re-based to start at 0 (this relies on the
    offsets container supporting element-wise subtraction, e.g. an array),
    and its quant-state keys are renumbered from 0.
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    tail_states = {}
    for shard in range(1, len(offsets) - 1):
        tail_states[shard - 1] = states[shard]

    head = dict(bnb_shard_offsets=head_offsets, bnb_quant_state=head_states)
    tail = dict(bnb_shard_offsets=tail_offsets, bnb_quant_state=tail_states)
    return head, tail


162
class LinearMethodBase(QuantizeMethodBase):
    """Interface shared by all linear-layer methods, quantized or not.

    Concrete subclasses allocate a layer's parameters in ``create_weights``
    and implement the forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the parameters of a linear layer and attach them to it.

        Args:
            layer: Module that receives the created parameters as attributes.
            input_size_per_partition: Width of the weight's input dimension
                held by this rank.
            output_partition_sizes: Width of the output dimension of every
                logical weight on this rank; for a fused QKV layer this lists
                the widths of Wq, Wk and Wv.
            input_size: Width of the input dimension across all ranks.
            output_size: Width of the output dimension across all ranks.
            params_dtype: dtype of the created parameters.
            **extra_weight_attrs: Extra keyword attributes forwarded by the
                caller (e.g. ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer on ``x`` with the parameters attached to ``layer``.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two weight layouts are used:

    * default: ``weight`` has shape
      ``(sum(output_partition_sizes), input_size_per_partition)`` and the
      forward pass is delegated to ``dispatch_unquantized_gemm()``;
    * "NN": when ``envs.VLLM_USE_NN`` is set, ``weight`` is allocated as
      ``(input_size_per_partition, sum(output_partition_sizes))`` and the
      forward pass is a plain ``x @ weight`` (plus bias).

    The NN forward path is also taken when the ``LLAMA_NN`` environment
    variable equals ``'1'`` at construction time.
    """

    def __init__(self):
        # Both flags are sampled once, when this method object is created.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # NOTE(review): not read by this class; kept in case other modules
        # inspect it.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the unquantized weight and register it on ``layer``.

        The weight is not sharded further here; it holds
        ``sum(output_partition_sizes) * input_size_per_partition`` elements
        (uninitialized). Its orientation depends on ``envs.VLLM_USE_NN``.

        Raises:
            KeyError: if ``extra_weight_attrs`` has no ``"weight_loader"``.
            RuntimeError: if allocation fails with a CUDA out-of-memory
                error (the original error is chained).
        """
        weight_loader = extra_weight_attrs.pop("weight_loader")
        total_out = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            shape = (input_size_per_partition, total_out)
        else:
            shape = (total_out, input_size_per_partition)

        # NOTE(review): input_dim/output_dim keep their default-layout values
        # even when the weight is stored transposed; the NN-aware loaders in
        # this file appear to compensate — confirm against the parameter
        # classes before changing.
        try:
            weight = ModelWeightParameter(
                data=torch.empty(*shape, dtype=params_dtype),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader)
        except torch.cuda.OutOfMemoryError as e:
            logger.error("Failed to create unquantized linear weights: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to create unquantized linear weights. "
                "This may be caused by insufficient memory to allocate "
                "the weight.") from e

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU platforms, hand the weight to the CPU GEMM dispatcher."""
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the layer's weight (and optional bias) to ``x``.

        NN layout: ``x @ layer.weight (+ bias)``. Default layout: delegated
        to ``dispatch_unquantized_gemm()``.
        """
        # NOTE(review): LLAMA_NN=1 selects the NN forward path, but
        # create_weights only allocates the transposed layout when
        # envs.VLLM_USE_NN is set — confirm the two flags are always enabled
        # together, otherwise the matmul below uses an (out, in) weight.
        if self.use_llama_nn or envs.VLLM_USE_NN:
            return self._apply_nn(x, layer.weight, bias)
        return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

    @staticmethod
    def _apply_nn(x: torch.Tensor, weight: torch.Tensor,
                  bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``x @ weight (+ bias)`` for a weight stored as (in, out)."""
        if bias is None:
            return torch.matmul(x, weight)
        if len(x.shape) == 2:
            # Fused bias-add + GEMM when x is 2-D.
            return torch.addmm(bias, x, weight)
        return torch.matmul(x, weight) + bias
279

280

281
class LinearBase(CustomOp):
    """Common base class for the linear layers in this module.

    Args:
        input_size: Input feature dimension of the layer.
        output_size: Output feature dimension of the layer.
        skip_bias_add: If true, the bias is not added but returned instead.
        params_dtype: dtype of the parameters; ``None`` selects
            ``torch.get_default_dtype()``.
        quant_config: Quantization configuration; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Prefix for parameter names.
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, the layer acts as rank 0 of a tensor-parallel
            group of size 1.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        super().__init__()

        # Record constructor arguments before the quant method is chosen:
        # get_quant_method receives this (partially built) layer.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (params_dtype if params_dtype is not None else
                             torch.get_default_dtype())
        self.quant_config = quant_config
        self.prefix = prefix

        method: Optional[QuantizeMethodBase]
        if quant_config is not None:
            method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            method = UnquantizedLinearMethod()
        self.quant_method = method

        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if disable_tp:
            self.tp_rank = 0
            self.tp_size = 1
        else:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()

    def update_param_tp_status(self):
        """Copy this layer's ``tp_rank``/``tp_size`` onto its vLLM params."""
        vllm_params = (p for p in self.parameters()
                       if isinstance(p, BasevLLMParameter))
        for p in vllm_params:
            p.tp_rank = self.tp_rank
            p.tp_size = self.tp_size
336
337


338
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Stored as ``self.eps``; not read by this class's own methods.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: Take no effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # If MergedReplicatedLinear, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = self.output_sizes
        else:
            self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        # Controls whether weight_loader transposes weights in NN mode.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a checkpoint tensor into ``param`` (no sharding: replicated).

        Raises:
            AssertionError: if the (possibly reshaped/transposed) tensor does
                not match ``param``'s size.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # In NN mode unquantized params are stored transposed relative to the
        # checkpoint; .t() is a no-op for 1-d tensors such as the bias.
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer; optionally return the un-added bias as well."""
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None

        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

451

452
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class: the
                       per-partition sizes come from ``self.output_sizes``
                       when a subclass has set that attribute.
        eps: Stored as ``self.eps``; not read by this class's own methods.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the last dimension. TP info is needed
        # before LinearBase.__init__ to compute the partition sizes.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.eps = eps
        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # self.params_dtype already resolves a None params_dtype to the
            # default dtype, matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()
        # Controls the NN-mode transpose/shard logic in the weight loaders.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``."""
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                assert final_shape[output_dim] % self.tp_size == 0
                final_shape[output_dim] = (final_shape[output_dim] //
                                           self.tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or self.is_quantization):
                shard_size = param_data.shape[output_dim]
            else:
                # NN layout stores the 2-d weight transposed, so the param's
                # output axis is the other one.
                shard_size = param_data.shape[int(not output_dim)]
            start_idx = self.tp_rank * shard_size
            # Narrow in checkpoint layout, before any transpose.
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate column-parallel loading to the vLLM parameter itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(
            loaded_weight=loaded_weight, is_quantization=self.is_quantization)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer; all-gather outputs across TP ranks if requested."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", gather_output={self.gather_output}"
        return s

632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        eps: Stored on the layer as ``self.eps``. No code in this class reads
             it (presumably consumed elsewhere, e.g. by a fused kernel —
             TODO confirm). Note that it is positional and precedes
             ``prefix``, so pass ``prefix`` by keyword.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.mlp.gate_up_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, all weights matrix won't be sharded, this layer
                    will be treated as a "Replicated" MergedLinear.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.eps = eps
        self.output_sizes = output_sizes
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)

        # Every logical sub-matrix must split evenly across the TP ranks.
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        # True for any quant method other than the plain unquantized one.
        # The VLLM_USE_NN transposed-layout handling below only applies when
        # this is False.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor (or one logical shard of it) into *param*.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                sub-matrix ``loaded_weight`` belongs to. If ``None``, the
                checkpoint already stores the fused matrix. When *param* has
                an ``output_dim``, that matrix is split by
                ``self.output_sizes`` and this method recurses once per
                shard.

        With ``envs.VLLM_USE_NN`` on an unquantized layer, the loaded tensor
        is transposed (``.t()``) before copying. Multi-dimensional params are
        therefore narrowed along the axis other than ``output_dim``.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are staged in the param's container, not copied.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            current_shard_offset = 0
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        if output_dim is not None:
            shard_offset = (sum(self.output_sizes[:loaded_shard_id]) //
                            self.tp_size)
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # Under VLLM_USE_NN (unquantized), multi-dim params hold the
            # transposed layout (the loaded weight is ``.t()``-ed below), so
            # the shard lives along the axis other than ``output_dim``.
            if (not envs.VLLM_USE_NN or self.is_quantization
                    or param_data.dim() == 1):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            start_idx = self.tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        # Match the transposed storage used under VLLM_USE_NN (unquantized).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load *loaded_weight* through ``param.load_merged_column_weight``.

        With no ``loaded_shard_id``, per-tensor scales go to shard 0. Plain
        ``RowvLLMParameter`` / ``BasevLLMParameter`` instances (exact type,
        not subclasses) load the whole tensor. Anything else is treated as a
        fused on-disk matrix and split via
        ``_load_fused_module_from_checkpoint``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are counted in (ceil-divided) blocks of block_n.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // self.tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // self.tp_size)
        else:
            shard_offset = sum(
                self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size,
                                        tp_rank=self.tp_rank,
                                        is_quantization=self.is_quantization)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = (get_tensor_model_parallel_world_size()
                   if not disable_tp else 1)
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head, which
            # is replicated across num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)
        # True for any quant method other than the plain unquantized one.
        # The VLLM_USE_NN transposed-layout handling below only applies when
        # this is False.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return this rank's offset of the q/k/v (or "total") segment.

        Returns ``None`` for an unknown shard id.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return this rank's size of the q/k/v segment (``None`` if unknown)."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load *loaded_weight* through ``param.load_qkv_weight``.

        With no ``loaded_shard_id``, per-tensor scales go to shard 0. Plain
        ``RowvLLMParameter`` / ``BasevLLMParameter`` instances (exact type,
        not subclasses) load the whole tensor. Anything else is split as a
        fused on-disk QKV matrix.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      shard_id=0,
                                      tp_rank=self.tp_rank)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight,
                                      tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by quant method
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are counted in (ceil-divided) blocks of block_n.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size,
                              tp_rank=self.tp_rank,
                              is_quantization=self.is_quantization)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor (or its q/k/v shard) into *param*.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``. If ``None``, the
                checkpoint already stores the fused QKV matrix. When *param*
                has an ``output_dim``, that matrix is split and this method
                recurses once per shard.

        KV shards are read from ``tp_rank // num_kv_head_replicas``, so
        replicated KV heads load the same slice on several ranks. With
        ``envs.VLLM_USE_NN`` on an unquantized layer, the loaded tensor is
        transposed (``.t()``) before copying. Multi-dimensional params are
        therefore narrowed along the axis other than ``output_dim``.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                # GGUF shards are staged in the param's container, not copied.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Loop-invariant: offsets of q/k/v in the unsharded matrix.
                orig_qkv_offsets = {
                    "q": (0, self.total_num_heads * self.head_size),
                    "k": (self.total_num_heads * self.head_size,
                          self.total_num_kv_heads * self.head_size),
                    "v":
                    ((self.total_num_heads + self.total_num_kv_heads) *
                     self.head_size,
                     self.total_num_kv_heads * self.head_size),
                    "total":
                    ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                     self.head_size, 0)
                }

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Same per-rank q/k/v layout that weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            # Under VLLM_USE_NN (unquantized), multi-dim params hold the
            # transposed layout (the loaded weight is ``.t()``-ed below), so
            # the shard lives along the axis other than ``output_dim``.
            if (not envs.VLLM_USE_NN or param_data.dim() == 1
                    or self.is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            # Query heads are partitioned by rank. KV heads may be replicated
            # across num_kv_head_replicas ranks, which then share one slice.
            if loaded_shard_id == "q":
                shard_id = self.tp_rank
            else:
                shard_id = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        # Match the transposed storage used under VLLM_USE_NN (unquantized).
        if envs.VLLM_USE_NN and not self.is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
        disable_tp: If true, weights matrix won't be sharded through tp rank.

    Note:
        When ``envs.VLLM_USE_NN`` is enabled and the layer is unquantized,
        ``weight_loader`` transposes each checkpoint tensor before copying it
        into the parameter, i.e. the parameter holds the transposed layout.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias,
                         disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        self.update_param_tp_status()
        # Read at load time by the weight loaders to decide whether the
        # VLLM_USE_NN transposed layout applies.
        self.is_quantization = not isinstance(self.quant_method,
                                              UnquantizedLinearMethod)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The checkpoint tensor is narrowed along ``param.input_dim`` (when that
        attribute exists and the weight is not already sharded), 0-dim
        tensors are promoted to shape (1,), and the result is copied in place.

        When ``envs.VLLM_USE_NN`` is enabled for an unquantized layer, the
        parameter holds the transposed weight: the shard size is read from
        the other axis of the parameter and the loaded tensor is transposed
        before the copy.

        Raises:
            AssertionError: if the final shapes do not match.
        """
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
        # True when the parameter stores the transposed (VLLM_USE_NN) layout.
        use_nn_layout = envs.VLLM_USE_NN and not self.is_quantization

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Use `is not None` so that an input_dim of 0 is sharded too
            # (a plain truthiness test would skip it).
            if input_dim is not None:
                weight_shape[input_dim] = (weight_shape[input_dim] //
                                           self.tp_size)
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # In the transposed layout the parameter's input axis is the
            # other one (0 <-> 1); assumes a 2-D weight in that case.
            param_input_dim = int(not input_dim) if use_nn_layout else input_dim
            shard_size = param_data.shape[param_input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if use_nn_layout:
            # Tensor.t() returns 0-D and 1-D tensors (e.g. bias) unchanged.
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` through the parameter's own row-parallel
        loader, forwarding whether this layer is quantized."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight,
                                       is_quantization=self.is_quantization)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        If ``input_is_parallel`` is False, the input is split along its last
        dimension and this rank's slice is used. The partial outputs are
        all-reduced when ``reduce_results`` is set and tp_size > 1.

        Returns:
            ``output`` if ``return_bias`` is False, otherwise
            ``(output, bias)`` where bias is only non-None when
            ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self, input_parallel, bias_)

        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1462
1463


1464
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    The query projection is applied to the decoder hidden states and the
    key/value projection to the encoder hidden states. Two internal layers
    (a ``ColumnParallelLinear`` for q and a ``QKVParallelLinear`` with zero
    query heads for k/v) are held in a plain dict, so their parameters are
    not auto-registered on this module.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
                            NOTE(review): this class treats None as 0 when
                            computing output_size and forwards None unchanged
                            to QKVParallelLinear(total_num_heads=0, ...);
                            confirm the effective default.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """The decoder query projection, with placeholder attrs synced."""
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """The encoder key/value projection, with placeholder attrs synced."""
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but missing on
        ``tgt_param`` onto ``tgt_param``.

        For bitsandbytes 4-bit params the shard attributes are split between
        the q and kv projections according to ``mode``.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things."""
        # ignore weight_loader because it's always different
        key_to_ignore = ["weight_loader", "_weight_loader"]
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in key_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in key_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param, 
        return the corresponding param in the proj layers.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        assert len(target_param_list) == 1
        target_param = target_param_list[0]
        return target_param

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to q and (optionally) encoder states to k/v.

        Returns:
            ``(q, k, v)``. ``k`` and ``v`` are None when
            ``encoder_hidden_states`` is None.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _resolve_proj_target(
        self,
        param: torch.nn.Parameter,
        loaded_shard_id: Optional[str],
    ) -> tuple[nn.Module, torch.nn.Parameter, tuple[Optional[str], ...]]:
        """Pick the sub-layer, its matching parameter, and the extra
        shard-id args for forwarding a load of placeholder ``param``."""
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        # The q projection is a ColumnParallelLinear, so no shard id is
        # forwarded to it; the kv QKVParallelLinear receives the shard id.
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        return layer, target_param, shard_id_args

    def weight_loader_v1(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Forward a load to the sub-layer's v1 ``weight_loader``."""
        # just like all other parameters, does not yet
        # support loading bias with weight_loader_v2
        layer, target_param, shard_id_args = self._resolve_proj_target(
            param, loaded_shard_id)
        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Forward a load to the sub-layer, using ``weight_loader_v2`` when
        this layer's quant method supports it."""
        layer, target_param, shard_id_args = self._resolve_proj_target(
            param, loaded_shard_id)
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s