linear.py 88.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from abc import abstractmethod
zhuwenwen's avatar
zhuwenwen committed
6
from typing import Any, Literal, Optional, Union, List
7
import vllm.envs as envs
8
import torch
9
import torch.nn as nn
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
from vllm import envs
13
14
15
16
17
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
18
from vllm.logger import init_logger
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
22
# yapf: disable
23
from vllm.model_executor.parameter import (BasevLLMParameter,
24
                                           BlockQuantScaleParameter,
25
                                           PackedColumnParameter,
26
                                           PackedvLLMParameter,
27
28
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
29
# yapf: enable
30
from vllm.model_executor.utils import set_weight_attrs
31
from vllm.platforms import current_platform
gaoqiong's avatar
gaoqiong committed
32

zhuwenwen's avatar
zhuwenwen committed
33
import os
34
from vllm.model_executor.utils import gemm_bank_conf
35

36
37
38
39
40
if envs.USE_FUSED_RMS_QUANT:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: Import fused rmsquant error: {e}") 
41
42
43
44
45
46
47
if envs.USE_FUSED_SILU_MUL_QUANT:        
    try:
        # from lightop import fuse_silu_mul_quant
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: Import fused silu_mul_qunat error: {e}")
        
48
49
logger = init_logger(__name__)

# Quantization method class names (matched against
# ``quant_method.__class__.__name__``) for which layers register
# ``weight_loader_v2`` instead of the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "ModelOptNvFp4LinearMethod",
    "BlockInt8LinearMethod",
]
72

73

74
75
76
77
78
79
80
81
82
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Convert a shard's size and offset into BitBLAS tile units.

    If ``param`` has no ``bitblas_tile_size`` attribute (or it is ``None``)
    the values are returned unchanged; otherwise both are floor-divided by
    the tile size.

    Returns:
        ``(shard_size, shard_offset)`` in the adjusted units.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile


83
84
85
86
87
88
89
90
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the Marlin tile size.

    If ``param`` has a ``marlin_tile_size`` attribute that is not ``None``,
    both values are multiplied by it; otherwise they pass through unchanged.

    Returns:
        ``(shard_size, shard_offset)`` in the adjusted units.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


91
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Map a logical shard's (offset, size) onto the packed BnB parameter.

    ``shard_offsets`` maps shard ids to ``(offset, size)`` pairs in
    unquantized rows; the ``"total"`` entry's first element is the total
    unquantized row count. Both values of the requested shard are rescaled
    proportionally to ``param.data.shape[0]`` (floor division).

    Returns:
        ``(quantized_size, quantized_offset)`` -- note size comes first.
    """
    unquantized_rows, _ = shard_offsets["total"]
    shard_start, shard_len = shard_offsets[loaded_shard_id]

    packed_rows = param.data.shape[0]
    packed_start = shard_start * packed_rows // unquantized_rows
    packed_len = shard_len * packed_rows // unquantized_rows

    return packed_len, packed_start


106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused per-shard scale array for one shard.

    Fused modules (QKV, merged MLP) keep one scale per logical matrix in a
    length-N ``param``. ``loaded_weight`` is the scale of a single on-disk
    shard; this returns the matching slot of ``param`` together with the
    (unwrapped) loaded value.

    Args:
        param: The fused scale array, indexed by shard position.
        loaded_weight: The shard's scale; either 0-d or of leading size 1.
        shard_id: ``"q"``/``"k"``/``"v"`` or an integer position.

    Returns:
        ``(param[shard_id], loaded_weight)``; the loaded value is
        transposed with ``.t()`` when ``envs.VLLM_USE_NN`` is set.

    Raises:
        ValueError: if ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: if ``shard_id`` is a string other than q/k/v.
    """
    if isinstance(shard_id, str):
        shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are 0-d; compressed-tensors scales carry a leading
    # dimension of size 1 that is unwrapped here.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    slot = param[shard_id]
    if envs.VLLM_USE_NN:
        return slot, loaded_weight.t()
    return slot, loaded_weight
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """
    Split BitsAndBytes 4-bit shard attributes into the first shard and the rest.

    ``bnb_shard_offsets`` must support slicing and scalar subtraction
    (e.g. a numpy array). The right-hand offsets are rebased to start at 0,
    and the right-hand quant states are renumbered from 0.

    For example, given bnb weight attributes as below:
    {
        'bnb_shard_offsets': array([0, 4, 8, 16]),
        'bnb_quant_state': {0: ..., 1: ..., 2: ...},
    }

    The function will return:
    {
        'bnb_shard_offsets': array([0, 4]),
        'bnb_quant_state': {0: ...},
    }
    and
    {
        'bnb_shard_offsets': array([0, 4, 12]),
        'bnb_quant_state': {0: ..., 1: ...},
    }
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    head_offsets = offsets[:2]
    tail_offsets = offsets[1:] - offsets[1]

    states = bnb_weight_attrs["bnb_quant_state"]
    head_states = {0: states[0]}
    tail_states = {}
    for src_idx in range(1, len(offsets) - 1):
        tail_states[src_idx - 1] = states[src_idx]

    left = dict(bnb_shard_offsets=head_offsets, bnb_quant_state=head_states)
    right = dict(bnb_shard_offsets=tail_offsets, bnb_quant_state=tail_states)
    return left, right


169
class LinearMethodBase(QuantizeMethodBase):
    """Interface implemented by every (possibly quantized) linear method.

    Implementations allocate a layer's parameters in ``create_weights`` and
    run the forward computation in ``apply``.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate a linear layer's weights and attach them to ``layer``.

        Args:
            layer: The layer that owns this linear method.
            input_size_per_partition: Input-dimension size of the weight held
                by the current rank.
            output_partition_sizes: Output-dimension size of every logical
                weight held by the current rank; for a QKV layer this lists
                the widths of Wq, Wk and Wv.
            input_size: Input-dimension size of the full, unsharded weight.
            output_size: Output-dimension size of the full, unsharded weight.
            params_dtype: Data type of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights (e.g. a ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the weights stored on ``layer``.

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Layout/kernel variants are selected by flags:

    * ``LLAMA_NN=1`` (read once in ``__init__``): ``apply`` computes
      ``x @ layer.weight`` without transposing the weight.
    * ``GEMM_PAD=1`` (read once in ``__init__``; only consulted on the
      ``LLAMA_NN`` path): ``apply`` drops 32 trailing weight columns once per
      layer when ``gemm_bank_conf`` accepts the unpadded width.
    * ``envs.VLLM_USE_NN``: ``create_weights`` allocates the weight as
      ``(input_size_per_partition, sum(output_partition_sizes))`` instead of
      ``(sum(output_partition_sizes), input_size_per_partition)``.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an uninitialized ``weight`` and register it on ``layer``.

        ``input_dim=1``/``output_dim=0`` are set in both layouts; the weight
        loaders in this file adjust for the transposed ``VLLM_USE_NN``
        layout themselves.
        """
        out_features = sum(output_partition_sizes)
        if envs.VLLM_USE_NN:
            shape = (input_size_per_partition, out_features)
        else:
            shape = (out_features, input_size_per_partition)
        weight = Parameter(torch.empty(*shape, dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """On CPU with ``VLLM_CPU_SGL_KERNEL``, pre-pack weights for AMX.

        Sets ``layer.use_cpu_sgl`` to indicate whether packing happened.
        """
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            N, K = layer.weight.size()
            dtype = layer.weight.dtype
            if (torch._C._cpu._is_amx_tile_supported()
                    and dtype == torch.bfloat16 and N % 32 == 0
                    and K % 32 == 0):
                packed_weight = torch.ops._C.convert_weight_packed(
                    layer.weight)
                assert packed_weight.size() == layer.weight.size()
                layer.weight.copy_(packed_weight)
                if layer.bias is not None:
                    layer.bias = Parameter(layer.bias.to(torch.float32),
                                           requires_grad=False)
                layer.use_cpu_sgl = True
            else:
                logger.warning(
                    "CPU SGL kernels require Intel AMX support,"
                    " bfloat16 weight, IC and OC are divisible by 32.")
                layer.use_cpu_sgl = False

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None,
              output: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear projection of ``x`` with ``layer.weight``.

        Args:
            layer: Layer whose ``weight`` was created by ``create_weights``.
            x: Input activations.
            bias: Optional bias added to the product.
            residual: If given, the product is accumulated into this tensor
                in place via ``torch.addmm`` and the tensor is returned.
            output: Optional preallocated result buffer; when ``residual``
                is given it must be ``None`` or ``residual`` itself.

        Returns:
            The result (``residual``/``output`` itself when supplied).
        """
        if self.use_llama_nn:
            # NOTE(review): the LLAMA_NN path ignores ``residual`` and
            # ``output``; confirm no caller relies on them in this mode.
            return self._apply_llama_nn(layer, x, bias)

        if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
            # Weight is stored (in, out) in this layout; transpose it back to
            # the (out, in) form passed to the gemm in the default path below.
            return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)

        weight = layer.weight
        if residual is not None:
            return self._accumulate_into_residual(x, weight, bias, residual,
                                                  output)
        if output is not None:
            if bias is not None:
                # Bias is always added separately when an output buffer is
                # provided.
                torch.matmul(x, weight.t(), out=output)
                output.add_(bias)
                return output
            return torch.matmul(x, weight.t(), out=output)
        return dispatch_unquantized_gemm()(x, layer.weight, bias)

    def _apply_llama_nn(self, layer: torch.nn.Module, x: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        """LLAMA_NN path: ``x @ layer.weight (+ bias)``."""
        # GEMM_PAD: the weight presumably carries 32 padding columns added at
        # load time (not visible here). Slice them off once per layer; the
        # slice is a view, so the padded row stride is kept.
        if (self.use_gemm_pad
                and not getattr(layer, "_gemm_pad_stripped", False)
                and gemm_bank_conf(layer.weight.shape[1] - 32)):
            # nn.Module refuses to assign a plain Tensor over a registered
            # Parameter (TypeError), so wrap the view in a Parameter.
            layer.weight = Parameter(layer.weight[:, :-32],
                                     requires_grad=False)
            # One-shot guard: re-checking on every call could keep trimming
            # real columns off the already-unpadded weight.
            layer._gemm_pad_stripped = True

        weight = layer.weight
        if bias is None:
            return torch.matmul(x, weight)
        if x.dim() == 2:
            return torch.addmm(bias, x, weight)
        return torch.matmul(x, weight) + bias

    @staticmethod
    def _accumulate_into_residual(x: torch.Tensor, weight: torch.Tensor,
                                  bias: Optional[torch.Tensor],
                                  residual: torch.Tensor,
                                  output: Optional[torch.Tensor]
                                  ) -> torch.Tensor:
        """Compute ``residual = beta*residual + x @ weight^T (+ bias)`` in place.

        Raises:
            AssertionError: if ``x`` has fewer than 2 dimensions or
                ``output`` is neither ``None`` nor ``residual``.
        """
        assert output is None or output is residual
        # With TP > 1 only rank 0 keeps the residual (beta=1); other ranks
        # write just their partial product (presumably summed by a later
        # all-reduce, so the residual is counted once).
        if (get_tensor_model_parallel_world_size() > 1
                and get_tensor_model_parallel_rank() != 0):
            beta = 0.0
        else:
            beta = 1.0
        # Write straight into ``residual`` to save device memory.
        if x.dim() == 2:
            torch.addmm(residual, x, weight.t(), beta=beta, out=residual)
        elif x.dim() >= 3:
            flat_residual = residual.view(-1, residual.size(-1))
            torch.addmm(flat_residual,
                        x.view(-1, x.size(-1)),
                        weight.t(),
                        beta=beta,
                        out=flat_residual)
        else:
            raise AssertionError(
                "unrecognized tensor dimensions: {}".format(x.dim()))
        if bias is not None:
            residual += bias
        return residual
306

307

zhuwenwen's avatar
zhuwenwen committed
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
class UnquantizedMoELinearMethod(LinearMethodBase):
    """MoE linear method without quantization.

    Allocates a single stacked weight holding every expert's matrix.
    ``apply`` is not implemented here (presumably the MoE layer consumes
    ``layer.weight`` directly -- TODO confirm).
    """

    def __init__(self):
        self.quant_config = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       num_experts: Optional[int] = None,
                       **extra_weight_attrs):
        """Allocate and register ``layer.weight`` for all experts.

        The weight has shape
        ``(num_experts, sum(output_partition_sizes), input_size_per_partition)``
        and is allocated on ``torch.cuda.current_device()``.

        Raises:
            ValueError: if ``num_experts`` is not provided.
        """
        if num_experts is None:
            # Without this check torch.empty fails with an opaque TypeError.
            raise ValueError(
                "UnquantizedMoELinearMethod.create_weights requires "
                "num_experts")
        weight = Parameter(torch.empty(num_experts,
                                       sum(output_partition_sizes),
                                       input_size_per_partition,
                                       device=torch.cuda.current_device(),
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 2, "output_dim": 1})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights to the input tensor (not implemented)."""
        raise NotImplementedError
    
    
342
343
class LinearBase(torch.nn.Module):
    """Common base of the linear layers in this module.

    Records the layer geometry and resolves the quantization method that
    subclasses use to create and apply their weights.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters; ``None`` means
            ``torch.get_default_dtype()``.
        quant_config: Quantization config; ``None`` selects
            ``UnquantizedLinearMethod``.
        prefix: Name of the layer in the state dict, forwarded to
            ``quant_config.get_quant_method``.
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        if quant_config is not None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                quant_config.get_quant_method(self, prefix=prefix))
        else:
            self.quant_method = UnquantizedLinearMethod()

        self.return_bias = return_bias

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Subclasses implement the projection."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer: every rank holds the full weight.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
            RMSNorm+quant path (``envs.USE_FUSED_RMS_QUANT``).
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)
        self.eps = eps

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        Handles GGUF weight-type params and lazy materialization, gives 0-d
        tensors (e.g. AutoFP8 scales) a shape of ``(1,)``, and transposes
        unquantized weights when ``envs.VLLM_USE_NN`` is set.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # Unquantized weights are allocated transposed under VLLM_USE_NN.
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self,
        input_: torch.Tensor,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        quant_args: Optional[list] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the projection, optionally fused with RMSNorm + int8 quant.

        Returns:
            ``output`` if ``return_bias`` is false. Otherwise:

            * fused RMSNorm+quant (``envs.USE_FUSED_RMS_QUANT``, ``rms_weight``
              given, ``quant_args`` not given):
              ``(output, residual, output_bias, [i_q, scales])``;
            * every other case (including pre-quantized ``quant_args``):
              ``(output, output_bias)``.
        """
        use_fused = envs.USE_FUSED_RMS_QUANT and (rms_weight is not None
                                                  or quant_args is not None)

        if use_fused and quant_args is None:
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            input_quant_args = [i_q, scales]
            output, output_bias = self._run_quant_method(
                input_, input_quant_args)
            if not self.return_bias:
                return output
            return output, residual, output_bias, input_quant_args

        if use_fused:
            # Caller supplied an already-quantized input.
            output, output_bias = self._run_quant_method(input_, quant_args)
        else:
            output, output_bias = self._run_quant_method(input_)
        if not self.return_bias:
            return output
        return output, output_bias

    def _run_quant_method(self, input_: torch.Tensor, *extra_args):
        """Apply the quant method; return ``(output, output_bias)``.

        ``extra_args`` are forwarded positionally after ``bias``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, input_, bias, *extra_args)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

526

527
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Accepted for compatibility;
                       this constructor does not read it (subclasses set
                       ``self.output_sizes`` before calling it instead).
        eps: Epsilon passed to ``lm_faster_rmsquant`` on the fused
            RMSNorm+quant path (``envs.USE_FUSED_RMS_QUANT``).
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(logical_size, self.tp_size)
                for logical_size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)
        self.eps = eps

        self.gather_output = gather_output

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        The full tensor is narrowed along ``param.output_dim`` to the current
        TP rank's shard unless the tensor is already sharded (or is a
        bitsandbytes 4-bit weight). Unquantized weights are transposed when
        ``envs.VLLM_USE_NN`` is set.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = (getattr(param, "is_sharded_weight", False)
                             or getattr(param, "use_bitsandbytes_4bit", False))
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                tp_size = get_tensor_model_parallel_world_size()
                assert final_shape[output_dim] % tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            # Unquantized 2-D weights are allocated transposed under
            # VLLM_USE_NN, so this rank's shard length sits on the other
            # axis of param_data; the checkpoint tensor keeps its layout.
            if (envs.VLLM_USE_NN and len(param_data.shape) != 1
                    and not is_quantization):
                shard_size = param_data.shape[int(not output_dim)]
            else:
                shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of size {loaded_weight.shape} "
            f"to a parameter of size {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to the parameter's own column-parallel loader."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True,
        output: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the column-parallel projection.

        With ``envs.USE_FUSED_RMS_QUANT`` and ``rms_weight`` given, the input
        is first RMS-normalized and int8-quantized by ``lm_faster_rmsquant``;
        ``output`` is ignored on that path.

        Returns:
            ``output`` if ``return_bias`` is false; otherwise
            ``(output, residual, output_bias)`` on the fused path and
            ``(output, output_bias)`` otherwise.
        """
        if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            input_quant_args = [i_q, scales]

            bias = self.bias if not self.skip_bias_add else None
            assert self.quant_method is not None
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)
            if self.gather_output:
                output = tensor_model_parallel_all_gather(output_parallel)
            else:
                output = output_parallel
            output_bias = self.bias if self.skip_bias_add else None
            if not self.return_bias:
                return output
            return output, residual, output_bias

        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        assert self.quant_method is not None
        # Only pass ``output`` when the caller supplied one: the base
        # ``apply(layer, x, bias)`` signature has no such keyword.
        if output is None:
            output_parallel = self.quant_method.apply(self, input_, bias)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      output=output)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        eps: Epsilon passed to the fused RMSNorm + quantization kernel
             (``lm_faster_rmsquant``) when ``envs.USE_FUSED_RMS_QUANT`` is
             enabled and ``forward`` is given an ``rms_weight``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        eps: Optional[float] = 1e-6,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Stored for the fused RMSNorm + quant path in ``forward``.
        self.eps = eps
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        # Every sub-projection must split evenly across tensor-parallel ranks.
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

    def forward(
        self, input_,
        rms_weight: Optional[torch.Tensor] = None,
        residual: Optional[torch.Tensor] = None,
        update_hd: Optional[bool] = True
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]],
               tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
        """Apply the merged projection, optionally fused with RMSNorm + quant.

        When ``envs.USE_FUSED_RMS_QUANT`` is enabled and ``rms_weight`` is
        given, ``lm_faster_rmsquant`` is run on ``input_`` and ``residual``
        (int8, ``epsilon=self.eps``, ``update_input=update_hd``), and its
        quantized activations and scales are forwarded to the quant method.

        Args:
            input_: Activations fed to the projection.
            rms_weight: RMSNorm weight; enables the fused path when set.
            residual: Residual tensor; required on the fused path.
            update_hd: Forwarded as ``update_input`` to the fused kernel.

        Returns:
            ``output`` when ``return_bias`` is False. Otherwise
            ``(output, output_bias)`` on the plain path, or
            ``(output, residual, output_bias)`` on the fused path.
        """
        fused = envs.USE_FUSED_RMS_QUANT and rms_weight is not None
        new_residual = None
        input_quant_args = None
        if fused:
            assert residual is not None
            i_q, scales = lm_faster_rmsquant(input=input_,
                                             rms_weight=rms_weight,
                                             epsilon=self.eps,
                                             quant_dtype=torch.int8,
                                             residual=residual,
                                             update_input=update_hd)
            # The same residual tensor object handed to the kernel is returned
            # to the caller (presumably updated in place - TODO confirm).
            new_residual = residual
            input_quant_args = [i_q, scales]

        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        if fused:
            output_parallel = self.quant_method.apply(self, input_, bias,
                                                      input_quant_args)
        else:
            output_parallel = self.quant_method.apply(self, input_, bias)

        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        if fused:
            return output, new_residual, output_bias
        return output, output_bias

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy a checkpoint tensor (or one sub-projection of it) into ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` selecting which
                merged sub-projection ``loaded_weight`` belongs to. ``None``
                means the checkpoint tensor already holds every
                sub-projection fused together; it is then split along
                ``output_dim`` and this method is re-invoked once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once both sub-projections have been collected.
                if len(param.data_container) == 2:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            current_shard_offset = 0
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size

            # Offsets of each sub-projection in the unquantized layout; the
            # table does not depend on the shard, so build it once.
            orig_offsets = None
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            shard_size, shard_offset = adjust_bitblas_shard(
                param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # With VLLM_USE_NN and unquantized weights, 2-D params are laid
            # out transposed relative to the checkpoint, so the shard is
            # taken along the other axis. A 1-D param (e.g. the bias) has no
            # other axis and must be narrowed along ``output_dim`` itself;
            # narrowing dim 1 of a 1-D tensor would raise.
            if (envs.VLLM_USE_NN and not is_quantization
                    and param_data.dim() > 1):
                narrow_dim = int(not output_dim)
            else:
                narrow_dim = output_dim
            param_data = param_data.narrow(narrow_dim, shard_offset,
                                           shard_size)

            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        # Match the transposed storage used under VLLM_USE_NN (``.t()`` is a
        # no-op for 0-D/1-D tensors).
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load into a ``BasevLLMParameter`` via its own sharding helpers.

        Args:
            param: Destination vLLM parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``; ``None`` when
                the checkpoint stores the sub-projections already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact type match on purpose: subclasses take the generic path.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block-scale offsets/sizes are counted in blocks of ``block_n``
            # output rows (rounded up) rather than in rows.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

zhuwenwen's avatar
zhuwenwen committed
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230

class MergedColumnParallelMoELinear(MergedColumnParallelLinear):
    """Merged column-parallel projection whose weights span ``num_experts``.

    Bypasses the parent constructors and creates weights directly through the
    quant method, passing ``num_experts`` as an extra argument. The layer
    never all-gathers its output and has no bias.

    Args:
        num_experts: Number of experts the weights are created for.
        input_size: Input dimension of the projection.
        output_sizes: Output dimension of each merged sub-projection; each
            must be divisible by the tensor-parallel world size.
        params_dtype: Parameter dtype; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Optional quantization config. When it is absent, or it
            resolves to ``UnquantizedLinearMethod``,
            ``UnquantizedMoELinearMethod`` is used instead.
        prefix: Layer name in the state dict, forwarded to the quant config.
    """

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_sizes: List[int],
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Only nn.Module is initialised; the parent linear constructors are
        # skipped because weights are created below with ``num_experts``.
        torch.nn.Module.__init__(self)
        self.num_experts = num_experts
        self.output_sizes = output_sizes
        self.input_size = input_size
        self.output_size = sum(output_sizes)
        tp_size = get_tensor_model_parallel_world_size()
        assert all(size % tp_size == 0 for size in output_sizes)
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [
            divide(size, tp_size) for size in self.output_sizes
        ]
        self.gather_output = False
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method = UnquantizedMoELinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
            # FIXME(ys): hack for moe
            if isinstance(self.quant_method, UnquantizedLinearMethod):
                self.quant_method = UnquantizedMoELinearMethod()

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         self.output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         self.num_experts,
                                         weight_loader=self.weight_loader)
        self.register_parameter("bias", None)

    def forward(self,
                input_,
                output: Optional[torch.Tensor] = None,
                expert_idx: int = -1):
        """Apply the quantized projection for ``expert_idx``.

        Returns ``None`` when the quant method is
        ``UnquantizedMoELinearMethod``; in that case the computation is done
        by an external MoE FFN (per the in-code note). Otherwise returns the
        result of ``quant_method.apply``, which receives ``output`` as an
        optional destination buffer.
        """
        if isinstance(self.quant_method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside
            return
        bias = None
        assert self.quant_method is not None

        output = self.quant_method.apply(self,
                                         input_,
                                         bias,
                                         expert_idx=expert_idx,
                                         output=output)
        return output


class QKVReplicatedLinear(ReplicatedLinear):
    """Fused QKV projection without tensor-parallel sharding.

    The weight covers all query, key and value heads:
    ``output_size = (total_num_heads + 2 * total_num_kv_heads) * head_size``.
    Checkpoint tensors tagged ``"q"``, ``"k"`` or ``"v"`` are copied into
    their slice of the fused weight by :meth:`weight_loader`.

    Args:
        hidden_size: Input hidden state size.
        head_size: Size of each attention head.
        total_num_heads: Number of query heads.
        total_num_kv_heads: Number of key/value heads; falls back to
            ``total_num_heads`` when falsy.
        bias: If true, add bias.
        skip_bias_add: If true, return the bias instead of adding it.
        params_dtype: Parameter dtype; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Optional quantization config.
        prefix: Layer name in the state dict, forwarded to the quant config.
        return_bias: If true, return bias together with outputs.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 return_bias: bool = True):

        # ReplicatedLinear.__init__ is skipped; its attributes are set here.
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.num_heads = total_num_heads
        self.num_kv_heads = total_num_kv_heads if total_num_kv_heads else total_num_heads
        self.input_size = self.hidden_size
        self.output_size = (self.num_heads +
                            2 * self.num_kv_heads) * self.head_size
        self.skip_bias_add = skip_bias_add
        self.return_bias = return_bias
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a q/k/v checkpoint tensor into its slice of ``param``.

        Args:
            param: Destination parameter (weight or bias) of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``; ``None`` means the
                checkpoint tensor is already the full fused QKV tensor.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            q_size = self.num_heads * self.head_size
            kv_size = self.num_kv_heads * self.head_size
            # Layout along the output dimension is [q | k | v].
            shard_offset, shard_size = {
                "q": (0, q_size),
                "k": (q_size, kv_size),
                "v": (q_size + kv_size, kv_size),
            }[loaded_shard_id]
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

            # With VLLM_USE_NN and unquantized weights, 2-D params are laid
            # out transposed relative to the checkpoint, so the slice is
            # taken along the other axis. A 1-D param (the bias, which has
            # ``output_dim`` 0) has no other axis; narrowing dim 1 of it
            # would raise.
            if (envs.VLLM_USE_NN and not is_quantization
                    and param_data.dim() > 1):
                narrow_dim = int(not output_dim)
            else:
                narrow_dim = output_dim
            param_data = param_data.narrow(narrow_dim, shard_offset,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVReplicatedLinear, assume the weight is the same "
                    "for all partitions.")

        # Match the transposed storage used under VLLM_USE_NN (``.t()`` is a
        # no-op for 0-D/1-D tensors).
        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
        
        
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Not enough KV heads to split: every rank holds exactly one KV
            # head, and each KV head is replicated on
            # tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank start offset of shard ``q``/``k``/``v`` (or the
        fused ``total`` width) along the output dimension, or None for an
        unknown id."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank width of shard ``q``/``k``/``v`` along the
        output dimension, or None for an unknown id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_bnb_qkv_offsets(
            self, num_heads: int,
            num_kv_heads: int) -> dict[str, tuple[int, int]]:
        """Build the ``{shard_id: (offset, size)}`` table consumed by
        ``adjust_bitsandbytes_4bit_shard`` for the given head counts.

        The ``"total"`` entry carries the full fused width as its offset and
        a size of 0.
        """
        q_size = num_heads * self.head_size
        kv_size = num_kv_heads * self.head_size
        return {
            "q": (0, q_size),
            "k": (q_size, kv_size),
            "v": (q_size + kv_size, kv_size),
            "total": (q_size + 2 * kv_size, 0),
        }

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into a ``BasevLLMParameter``.

        Args:
            param: Destination parameter; it performs the actual copy via
                ``load_qkv_weight``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``; None when the
                checkpoint already stores a fused qkv tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            assert hasattr(self.quant_method, "quant_config")
            weight_block_size = self.quant_method.quant_config.weight_block_size
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Scales are stored per block, so convert element offsets/sizes
            # into (ceil-divided) block counts.
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a q/k/v shard, or an already-fused qkv tensor, into ``param``.

        Args:
            param: Destination parameter. Optional attributes on it select
                special handling (GGUF, packing, bitsandbytes, AQLM metadata,
                per-tensor scales).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``; None when the
                checkpoint stores the fused qkv tensor, in which case the
                tensor is split and this method recurses once per shard.

        When ``envs.VLLM_USE_NN`` is set and the layer is unquantized, the
        destination is laid out transposed relative to the checkpoint. That
        is why it is narrowed on the other axis, and why the loaded tensor is
        transposed before the final copy.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Once q, k and v have all arrived, build the nested param.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Loop-invariant: offsets of the full (unsharded) tensor.
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.total_num_heads, self.total_num_kv_heads)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.num_heads, self.num_kv_heads)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            if (not envs.VLLM_USE_NN or len(param_data.shape) == 1
                    or is_quantization):
                param_data = param_data.narrow(output_dim, shard_offset,
                                               shard_size)
            else:
                # Transposed (VLLM_USE_NN) 2-D weight: the output axis is the
                # other one.
                param_data = param_data.narrow(int(not output_dim),
                                               shard_offset, shard_size)

            # Query heads are partitioned across ranks; KV heads may be
            # replicated, so several ranks read the same KV slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1590
1591


1592
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, call all-reduce on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y = X_iA_i
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
        return_bias: If true, return bias together with outputs in forward pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        # Imported lazily here rather than at module top level.
        from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
        self.tbo_all_reduce = tbo_all_reduce

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        The checkpoint tensor is narrowed along ``param.input_dim`` (when
        set and the weight is not pre-sharded) to this rank's partition.
        When ``envs.VLLM_USE_NN`` is set and the layer is unquantized, the
        destination is laid out transposed relative to the checkpoint. The
        partition size is therefore read from the other axis, and the
        tensor is transposed before the final copy.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: input_dim == 0 is a valid axis
            # and must also be split across ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        is_quantization = not isinstance(self.quant_method,
                                         UnquantizedLinearMethod)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            if not envs.VLLM_USE_NN or is_quantization:
                shard_size = param_data.shape[input_dim]
            else:
                # Transposed (VLLM_USE_NN) 2-D weight: the input extent lives
                # on the other axis of the destination.
                shard_size = param_data.shape[int(not input_dim)]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if envs.VLLM_USE_NN and not is_quantization:
            loaded_weight = loaded_weight.t()

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_row_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_,
        use_fused_silu_mul_quant: Optional[bool] = False,
        residual=None,
        output=None,
        disable_allreduce=False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the row-parallel GEMM, all-reduce, and bias add.

        Args:
            input_: Layer input. If ``input_is_parallel`` is False it is split
                along the last dimension and this rank's slice is used.
            use_fused_silu_mul_quant: If true, ``lm_fuse_silu_mul_quant`` is
                applied to the input and its ``(xq, xs)`` results are passed
                to ``quant_method.apply`` as ``silu_quant_args``.
            residual: Optional tensor forwarded to ``quant_method.apply``
                (presumably fused into the GEMM output — confirm in the quant
                method). On ranks other than 0, when ``tp_size > 1``, a zero
                tensor is passed instead, so the residual enters the
                all-reduced sum only once. The caller's tensor is not
                modified.
            output: Optional output buffer forwarded to
                ``quant_method.apply``. It is only allowed when this call
                will not all-reduce.
            disable_allreduce: Skip the all-reduce even if
                ``reduce_results`` is set.

        Returns:
            ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Only add residual on the first rank. Build a fresh zero tensor
        # instead of zeroing in place, so the caller's residual is untouched.
        if residual is not None and self.tp_size > 1 and self.tp_rank != 0:
            residual = torch.zeros_like(residual)

        # Matrix multiply.
        if output is not None:
            assert disable_allreduce or not self.reduce_results
        assert self.quant_method is not None

        if use_fused_silu_mul_quant:
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
            silu_quant_args = [xq, xs]
            # Bias is added once, after the all-reduce below (same as the
            # non-fused path). Fusing it into the GEMM here as well added it
            # twice.
            # NOTE(review): residual/output are not forwarded on this path;
            # confirm the fused silu-quant kernel never needs them.
            output_parallel = self.quant_method.apply(
                self,
                input_parallel,
                bias=None,
                silu_quant_args=silu_quant_args)
        else:
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      residual=residual,
                                                      output=output)

        if self.reduce_results and self.tp_size > 1 and not disable_allreduce:
            if envs.VLLM_ENABLE_TBO:
                output_ = self.tbo_all_reduce(output_parallel)
            else:
                output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        # NOTE(review): with disable_allreduce and tp_size > 1 the bias is
        # added to every rank's partial sum; if the caller reduces afterwards
        # it is counted tp_size times — confirm callers never combine these.
        if self.skip_bias_add:
            output = output_
            output_bias = self.bias
        else:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None

        # NOTE(review): `return_bias` is accepted by __init__ but not honored
        # here; a tuple is always returned.
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
1809
1810


1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
class QKVCrossParallelLinear(LinearBase):
    """Linear layers for efficient cross-attention's QKV transformation.

    The query projection (``q_proj_decoder``) is applied to the decoder
    hidden states and the key/value projection (``kv_proj_encoder``) to the
    encoder hidden states. The two sub-layers are stored in a plain dict so
    their parameters are not auto-registered on this module; this module only
    owns zero-sized placeholder parameters and routes loaded weights to the
    matching sub-layer parameter, acting as a drop-in replacement for a
    ``QKVParallelLinear`` during weight loading.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Apply the documented default here. The KV sub-layer below is built
        # with total_num_heads=0, so forwarding None to it could not fall back
        # to the number of query heads.
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads

        # input_size and output_size are not used, just for alignment
        input_size = hidden_size
        output_size = (total_num_heads + total_num_kv_heads) * head_size
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

        self.quant_config = quant_config

        # Empty placeholders for loading as a single module: the state-dict
        # loader sees these parameters, and `weight_loader` redirects each
        # loaded tensor to the matching parameter of a sub-layer.
        placeholder_size = 0
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         placeholder_size, [placeholder_size],
                                         placeholder_size,
                                         placeholder_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.q_size = self.q_proj_decoder.output_size_per_partition
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            # Placeholder bias; its loader routes into the sub-layers.
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.bias = None

    def process_weights_after_loading(self):
        """Run the quant method's post-load processing on each sub-layer."""
        if self.quant_method is None:
            return
        for layer in self.proj.values():
            self.quant_method.process_weights_after_loading(layer)

    @property
    def q_proj_decoder(self) -> ColumnParallelLinear:
        """Query sub-layer, with placeholder attributes synced onto it.

        On every access, attributes set on this module's placeholder
        parameters that are missing from the same-named sub-layer parameter
        are copied over (see ``sync_weight_attrs``).
        """
        layer = self.proj["q_proj_decoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="q_proj_decoder")
        return layer

    @property
    def kv_proj_encoder(self) -> QKVParallelLinear:
        """Key/value sub-layer, with placeholder attributes synced onto it."""
        layer = self.proj["kv_proj_encoder"]
        for name, param in self.named_parameters():
            target_param = getattr(layer, name, None)
            if target_param is not None:
                self.sync_weight_attrs(param,
                                       target_param,
                                       mode="kv_proj_encoder")
        return layer

    def sync_weight_attrs(
        self,
        src_param: nn.Parameter,
        tgt_param: nn.Parameter,
        mode: Literal["q_proj_decoder", "kv_proj_encoder"],
    ):
        """Copy attributes present on ``src_param`` but absent on ``tgt_param``.

        For bitsandbytes 4-bit parameters, the missing attributes are first
        split by ``left_shift_bitsandbytes_4bit_shard`` into a q part and a
        kv part, and only the part selected by ``mode`` is applied.
        """
        missing_attrs_dict = {
            k: getattr(src_param, k)
            for k in (set(vars(src_param).keys()) -
                      set(vars(tgt_param).keys()))
        }
        # TODO(Isotr0py): handle bitsandbytes 8bit
        use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
                                        False)
        if missing_attrs_dict and use_bitsandbytes_4bit:
            q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
                missing_attrs_dict)
            if mode == "q_proj_decoder":
                set_weight_attrs(tgt_param, q_proj_attrs)
            elif mode == "kv_proj_encoder":
                set_weight_attrs(tgt_param, kv_proj_attrs)
        else:
            set_weight_attrs(tgt_param, missing_attrs_dict)

    def _is_same_param(
        self,
        src_param: torch.nn.Parameter,
        map_param: torch.nn.Parameter,
    ) -> bool:
        """Check if two parameters are exactly pointing to same things.

        Two parameters match when they have the same type and identical
        instance attributes, ignoring the weight-loader attributes.
        """
        # ignore weight_loader because it's always different
        keys_to_ignore = ("weight_loader", "_weight_loader")
        has_same_type_name = type(src_param) is type(map_param)
        src_param_attrs = {
            k: v
            for k, v in src_param.__dict__.items() if k not in keys_to_ignore
        }
        map_param_attrs = {
            k: v
            for k, v in map_param.__dict__.items() if k not in keys_to_ignore
        }
        has_same_attrs = src_param_attrs == map_param_attrs
        return has_same_type_name and has_same_attrs

    def select_proj_params(
        self,
        layer: nn.Module,
        param: nn.Parameter,
    ) -> nn.Parameter:
        """
        Given the placeholder param,
        return the corresponding param in the proj layers.

        Raises:
            ValueError: if not exactly one parameter of ``layer`` matches.
        """
        target_param_list = [
            v for _, v in layer.named_parameters()
            if self._is_same_param(param, v)
        ]
        # Explicit check instead of `assert`, which is stripped under -O and
        # would then silently pick the first match (or raise IndexError).
        if len(target_param_list) != 1:
            raise ValueError(
                f"Expected exactly one parameter of {type(layer).__name__} "
                f"to match the placeholder, found {len(target_param_list)}.")
        return target_param_list[0]

    def forward(  # type: ignore[override]
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Project decoder states to q and, if given, encoder states to k/v.

        Returns:
            ``(q, k, v)``; ``k`` and ``v`` are ``None`` when
            ``encoder_hidden_states`` is ``None``.
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Route a loaded tensor from a placeholder to the real sub-layer.

        Shard id ``"q"`` targets ``q_proj_decoder`` (called without a shard
        id); any other shard id targets ``kv_proj_encoder`` and is forwarded.
        """
        layer = (self.q_proj_decoder
                 if loaded_shard_id == "q" else self.kv_proj_encoder)
        target_param = self.select_proj_params(layer, param)
        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
        if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED:
            layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args)
        else:
            layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def extra_repr(self) -> str:
        """Summarize sizes, bias and tensor-parallel configuration."""
        s = f"in_features={self.input_size}"
        s += f", q_size={self.q_size}"
        s += f", kv_size={self.kv_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += ", gather_output=False"
        return s
    
    
class RowParallelMoELinear(RowParallelLinear):

    def __init__(self,
                 num_experts: int,
                 input_size: int,
                 output_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Create a row-parallel linear layer holding ``num_experts`` experts.

        Only ``torch.nn.Module.__init__`` is run; the parent linear-layer
        initializers are bypassed and the attributes this layer needs are set
        directly. The input dimension is divided across the tensor-parallel
        world, results are never all-reduced here, and no bias is registered.

        Args:
            num_experts: number of experts whose weights are created.
            input_size: full (unsharded) input dimension.
            output_size: output dimension.
            params_dtype: parameter dtype; defaults to torch's default dtype.
            quant_config: quantization config, or None for unquantized.
            prefix: layer name forwarded to ``get_quant_method``.
        """
        torch.nn.Module.__init__(self)
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reduce_results = False
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        method: Optional[QuantizeMethodBase]
        if quant_config is not None:
            method = quant_config.get_quant_method(self, prefix=prefix)
            # FIXME(ys): hack for moe -- the plain unquantized linear method
            # is swapped for its MoE-aware counterpart.
            if isinstance(method, UnquantizedLinearMethod):
                method = UnquantizedMoELinearMethod()
        else:
            method = UnquantizedMoELinearMethod()
        self.quant_method = method

        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size_per_partition,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            self.num_experts,
            weight_loader=self.weight_loader,
        )
        self.register_parameter("bias", None)

    def forward(  # type: ignore[override]
            self,
            input_,
            residual=None,
            expert_idx: int = -1,
            output: Optional[torch.Tensor] = None):
        """Apply this layer's quantized matmul for one expert.

        Args:
            input_: activations handed to ``quant_method.apply``.
            residual: accepted for signature compatibility; not used here.
            expert_idx: expert index forwarded to ``quant_method.apply``.
            output: optional output buffer forwarded to ``quant_method.apply``.

        Returns:
            The result of ``quant_method.apply``, or ``None`` when the layer
            uses ``UnquantizedMoELinearMethod``.
        """
        method = self.quant_method
        if isinstance(method, UnquantizedMoELinearMethod):
            # use optimus moe_ffn outside: nothing is computed here.
            return None
        assert method is not None
        # No bias is applied by this layer, so None is passed positionally.
        return method.apply(self,
                            input_,
                            None,
                            expert_idx=expert_idx,
                            output=output)