linear.py 58.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from abc import abstractmethod
5
from typing import Optional, List
6

7
import flux
8
9
import torch
import torch.nn.functional as F
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
17
from vllm.distributed.parallel_state import get_tp_group
18
from vllm.logger import init_logger
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
# yapf: disable
22
from vllm.model_executor.parameter import (BasevLLMParameter,
23
                                           BlockQuantScaleParameter,
24
                                           PackedColumnParameter,
25
                                           PackedvLLMParameter,
26
27
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
28
# yapf: enable
29
from vllm.model_executor.utils import set_weight_attrs
gaoqiong's avatar
gaoqiong committed
30

zhuwenwen's avatar
zhuwenwen committed
31
import os
32
from vllm.model_executor.utils import gemm_bank_conf
33
34
35

logger = init_logger(__name__)

36
# Class names of quantization methods whose parameters are loaded through
# ``weight_loader_v2`` (the layer picks the loader by checking
# ``quant_method.__class__.__name__`` against this list).
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "AWQMarlinLinearMethod",
    "AWQLinearMethod",
    "GPTQMarlinLinearMethod",
    "Fp8LinearMethod",
    "MarlinLinearMethod",
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod",
    "GPTQLinearMethod",
    "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod",
    "IPEXAWQLinearMethod",
    "IPEXGPTQLinearMethod",
    "HQQMarlinMethod",
    "QuarkLinearMethod",
    "BlockInt8LinearMethod",
]
44

45

46
47
48
49
50
51
52
53
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Object that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size before adjustment.
        shard_offset: Shard offset before adjustment.

    Returns:
        ``(shard_size, shard_offset)``. Both values are multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise they are returned unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


54
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized offset and size of the requested shard are rescaled by
    the ratio between the first dimension of ``param.data`` and the
    unquantized total.

    Args:
        param: Parameter whose ``data.shape[0]`` is the quantized total.
        shard_offsets: Maps each shard id to ``(offset, size)``. The entry
            under ``"total"`` supplies the unquantized total as its first
            element; its second element is ignored.
        loaded_shard_id: Key into ``shard_offsets`` of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note that size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before the floor division so precision is not lost early.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused scale array that a loaded scalar belongs to.

    Fused modules (QKV and MLP) keep an array of length N with one scale per
    "logical" matrix, while ``loaded_weight`` is the scale of a single shard
    on disk. This picks the slice of ``param`` addressed by ``shard_id``.

    Args:
        param: The fused array of scales.
        loaded_weight: The scale read from disk; either shapeless or with a
            leading dimension of exactly 1.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` (mapped to 0, 1, 2), or an
            integer index.

    Returns:
        ``(param[shard_id], loaded_weight)`` with ``loaded_weight`` reduced
        to its first element when it had a shape.

    Raises:
        ValueError: If ``shard_id`` is neither a string nor an int.
    """
    if isinstance(shard_id, str):
        slot = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        slot = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape, whereas compressed-tensors scales
    # do; unwrap the latter so both look the same to the caller.
    has_shape = len(loaded_weight.shape) != 0
    if has_shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[slot], loaded_weight


92
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights (e.g. a ``weight_loader`` callable).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two environment variables are read once, at construction time:

    * ``LLAMA_NN == '1'`` makes ``apply`` compute ``x @ layer.weight``
      (via ``torch.matmul`` / ``torch.addmm``) instead of ``F.linear``.
    * ``GEMM_PAD == '1'`` allows ``apply`` (in ``LLAMA_NN`` mode only) to
      drop the last 32 columns of ``layer.weight`` when ``gemm_bank_conf``
      says so.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialized, non-trainable ``weight`` on ``layer``.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and is tagged with ``input_dim=1`` and
        ``output_dim=0`` plus every entry of ``extra_weight_attrs``.
        ``input_size`` and ``output_size`` are not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the layer weight and optionally add ``bias``.

        Args:
            layer: Layer holding the ``weight`` parameter.
            x: Input tensor.
            bias: Optional bias added to the product.

        Returns:
            ``F.linear(x, layer.weight, bias)`` by default, or
            ``x @ layer.weight (+ bias)`` when ``LLAMA_NN`` mode is on.
        """
        if not self.use_llama_nn:
            return F.linear(x, layer.weight, bias)

        # NOTE(review): this path computes x @ layer.weight, so it presumably
        # relies on the weight having been re-laid-out as (in, out) somewhere
        # outside this block -- confirm.
        #
        # Use the flag cached in __init__ rather than os.environ['GEMM_PAD']:
        # indexing os.environ raises KeyError when the variable is unset.
        if (gemm_bank_conf(layer.weight.shape[1] - 32)
                and self.use_gemm_pad):
            # Drop the trailing 32 columns (presumably padding -- confirm
            # against gemm_bank_conf and wherever the padding is added).
            layer.weight = layer.weight[:, :-32]

        if bias is None:
            return torch.matmul(x, layer.weight)
        if len(x.shape) == 2:
            # 2-D input: fuse the bias add into the matmul.
            return torch.addmm(bias, x, layer.weight)
        return torch.matmul(x, layer.weight) + bias
164

165

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
class GemmRS(LinearMethodBase):
    """Fused Gemm-ReduceScatter without quantization.

    The fused kernel is a ``flux.GemmRS`` op built in ``create_weights`` and
    stored on this method object. ``LLAMA_NN == '1'`` (read once at
    construction) selects ``transpose_weight=True`` for that op.
    """

    def __init__(self, separate_bias_add: bool = False):
        # Stored but not consulted anywhere in this class.
        self.separate_bias_add = separate_bias_add
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``weight`` on ``layer`` and build the fused flux op.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``, is non-trainable, and is tagged with
        ``input_dim=1``, ``output_dim=0`` and ``extra_weight_attrs``.
        ``input_size`` is not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.gemm_rs_op = flux.GemmRS(
            get_tp_group().device_group,
            nnodes=1,  # One node
            max_m=8192,  # Max M. TODO: Pass in correctly.
            n_dim=output_size,  # N
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            input_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed.
            # The two configurations differ only in this flag, which
            # follows LLAMA_NN mode.
            transpose_weight=self.use_llama_nn,
            # Note: bfloat16 requires fuse_reduction=False.
            fuse_reduction=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused op on ``x`` and ``layer.weight``.

        ``bias`` must be ``None``. The op's result has its dimension 0
        squeezed before being returned.
        """
        assert bias is None, "GemmRS does not support a bias"

        output = self.gemm_rs_op.forward(x, layer.weight)
        return output.squeeze(0)


class AGCook(LinearMethodBase):
    """Fused AllGather-Gemm without quantization.

    The fused kernel is a ``flux.AGKernel`` op built in ``create_weights``
    and stored on this method object. ``LLAMA_NN == '1'`` (read once at
    construction) selects ``transpose_weight=True`` for that op.
    """

    def __init__(self, separate_bias_add: bool = False):
        # Stored but not consulted anywhere in this class.
        self.separate_bias_add = separate_bias_add
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register ``weight`` on ``layer`` and build the fused flux op.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)``, is non-trainable, and is tagged with
        ``input_dim=1``, ``output_dim=0`` and ``extra_weight_attrs``. The
        op's N and K dimensions are taken from that weight's shape;
        ``input_size`` and ``output_size`` are not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.ag_gemm_op = flux.AGKernel(
            get_tp_group().device_group,
            nnodes=1,  # One node
            full_m=8192,  # Max M. TODO: Pass in correctly.
            n_dim=weight.shape[0],  # N
            k_dim=weight.shape[1],  # K
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            input_dtype=params_dtype,
            output_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed.
            # The two configurations differ only in this flag, which
            # follows LLAMA_NN mode.
            transpose_weight=self.use_llama_nn,
            # Note: if local_copy=True, I hit the following runtime error:
            # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
            #   Check failed: 33554432((input.numel() * input.element_size()))
            #                 == 139836453421056((this->chunk_size))
            local_copy=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused op on ``x`` and ``layer.weight``.

        ``bias`` must be ``None``.
        """
        assert bias is None, "AGCook does not support a bias"

        return self.ag_gemm_op.forward(x, layer.weight)
    
304
305
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Stores the layer dimensions and picks the ``quant_method`` that
    subclasses use to create and apply their weights.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when ``None``.
        quant_config: Quantization configure.
        prefix: Forwarded to ``quant_config.get_quant_method``.
        fuse_gemm_rs: If true, use the fused ``GemmRS`` method. Requires
            ``quant_config`` to be ``None`` and takes precedence over
            ``fuse_ag_gemm``.
        fuse_ag_gemm: If true, use the fused ``AGCook`` method. Requires
            ``quant_config`` to be ``None``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        fuse_gemm_rs: bool = False,
        fuse_ag_gemm: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # The fused methods are unquantized, so they cannot be combined
        # with a quantization config.
        if fuse_gemm_rs:
            assert quant_config is None, (
                "fuse_gemm_rs cannot be combined with quant_config")
            self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
        elif fuse_ag_gemm:
            assert quant_config is None, (
                "fuse_ag_gemm cannot be combined with quant_config")
            self.quant_method = AGCook()
        elif quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Subclasses must override this; the base always raises."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The weight is created at full size (``input_size`` x ``output_size``)
    with no tensor-parallel partitioning.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``; the sizes must match."""
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"into a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer to ``x``.

        Returns:
            ``(output, output_bias)``. When ``skip_bias_add`` is set the bias
            is not applied and is returned as ``output_bias``; otherwise it
            is passed to the quant method and ``output_bias`` is ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the layer's dimensions and whether it has a bias."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

424

425
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This argument is not read
                       by this class; packed subclasses expose the sizes
                       through ``self.output_sizes`` instead.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        fuse_ag_gemm: Forwarded to ``LinearBase``; selects the fused
                       AllGather-Gemm method.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[list[int]] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix, fuse_ag_gemm=fuse_ag_gemm)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        # Pick the loader by the quant method's class name.
        use_v2_loader = (self.quant_method.__class__.__name__
                         in WEIGHT_LOADER_V2_SUPPORTED)
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader_v2
                           if use_v2_loader else self.weight_loader))

        if bias:
            # Use the resolved dtype (never None), as the weights do.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim`` attribute and is not marked as
        already sharded, ``loaded_weight`` is narrowed along that dimension
        to the slice starting at ``tp_rank * shard_size``. The result must
        match the shape of ``param.data``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {loaded_weight.shape} "
            f"into a parameter of shape {param_data.shape}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer to ``input_``.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across the
            partitions when ``gather_output`` is set. When ``skip_bias_add``
            is set the bias is not applied and is returned as
            ``output_bias``; otherwise ``output_bias`` is ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Summarize the per-partition dimensions and parallel settings."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
585
        quant_config: Quantization configuration.
586
587
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
588
589
    """

590
591
    def __init__(self,
                 input_size: int,
                 output_sizes: list[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        # Stored before the base initialiser runs.
        self.output_sizes = output_sizes

        # Every packed sub-layer has to split evenly across the TP ranks.
        world_size = get_tensor_model_parallel_world_size()
        assert not any(size % world_size for size in output_sizes)

        # The base class sees one layer whose width is the sum of the parts.
        super().__init__(
            input_size=input_size,
            output_size=sum(output_sizes),
            bias=bias,
            gather_output=gather_output,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            fuse_ag_gemm=fuse_ag_gemm,
        )
612
613
614
615
616

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Copy a checkpoint tensor into the matching slice of ``param``.

        Args:
            param: Destination parameter. Optional attributes read from it
                (``is_gguf_weight``, ``is_gguf_weight_type``, ``output_dim``,
                ``packed_dim``, ``is_metadata``, ``needs_scalar_to_array``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``ignore_warning``) select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the sub-layer
                being loaded. ``None`` means the checkpoint stores the
                sub-layers already fused: if ``param`` has an ``output_dim``
                the tensor is split per sub-layer and this method is called
                again for each piece, otherwise it is copied as a whole.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                # No shard id: the same type is recorded for every sub-layer.
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # NOTE(review): the threshold is a literal 2 rather than
                # len(self.output_sizes) - confirm that layers merging more
                # than two sub-layers never take this path.
                if len(param.data_container) == 2:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # (shard_id, shard_offset, shard_size) in unsharded coordinates.
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    index = list(itertools.accumulate([0] + self.output_sizes))
                    orig_offsets = {
                        str(i): (index[i], size)
                        for i, size in enumerate(self.output_sizes)
                    }
                    orig_offsets["total"] = (self.output_size, 0)
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id for this slice.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                # Size and offset are taken from the loaded tensor itself.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
760

761
762
763
764
765
766
767
768
769
770
771
772
773
    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """Split an on-disk fused MLP weight and load it shard by shard.

        Some checkpoints store the merged layers as one tensor, so no shard
        id accompanies the weight. The tensor is cut along
        ``param.output_dim`` into consecutive slices sized by
        ``self.output_sizes``, and each slice is passed to
        ``weight_loader_v2`` with its position as the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        packed_types = (PackedColumnParameter, PackedvLLMParameter)

        running_offset = 0
        for shard_id, full_size in enumerate(self.output_sizes):
            shard_offset, shard_size = running_offset, full_size
            running_offset += full_size

            # Special case for Quantization.
            # If quantized, the offset and size are adjusted to account for
            # the packing.
            if (isinstance(param, packed_types)
                    and param.packed_dim == param.output_dim):
                shard_size, shard_offset = (
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset))

            shard = loaded_weight.narrow(param.output_dim, shard_offset,
                                         shard_size)
            self.weight_loader_v2(param, shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` into ``param`` through its merged-column hook.

        With ``loaded_shard_id=None`` the weight is treated as already fused
        on disk and is either loaded whole (per-tensor scales, plain
        row/base parameters) or split by
        ``_load_fused_module_from_checkpoint``. Otherwise the offset and size
        of the requested sub-layer on this rank are computed and passed to
        ``param.load_merged_column_weight``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        # Granularity of the rows being addressed: 1 for ordinary parameters,
        # the quantization block height for block-wise scale parameters.
        block_n = 1
        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            supported_methods = (Fp8LinearMethod, Fp8MoEMethod,
                                 BlockInt8LinearMethod, BlockInt8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method, supported_methods)
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]

        # Ceil-divide into blocks first, then split across the TP ranks.
        rows_before = sum(self.output_sizes[:loaded_shard_id])
        rows_here = self.output_sizes[loaded_shard_id]
        shard_offset = ((rows_before + block_n - 1) // block_n) // tp_size
        shard_size = ((rows_here + block_n - 1) // block_n) // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
862
        quant_config: Quantization configuration.
863
864
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
865
866
    """

867
868
869
870
871
872
873
874
    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        # Without an explicit KV head count, use one KV head per query head.
        self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                   is None else total_num_kv_heads)

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # Enough KV heads for every rank to own distinct ones.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # At least as many ranks as KV heads: one KV head per rank,
            # shared by several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)

        q_width = self.num_heads * self.head_size * tp_size
        kv_width = self.num_kv_heads * self.head_size * tp_size
        self.output_sizes = [
            q_width,  # q_proj
            kv_width,  # k_proj
            kv_width,  # v_proj
        ]

        fused_width = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(
            input_size=self.hidden_size,
            output_size=fused_width,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            fuse_ag_gemm=fuse_ag_gemm,
        )
912

913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """Split an on-disk fused QKV weight and load q, k and v in turn.

        Some checkpoints store the QKV projection as one tensor, so no shard
        id accompanies the weight. The tensor is cut along
        ``param.output_dim`` using the unsharded head counts, and each slice
        is passed to ``weight_loader_v2`` with shard id ``"q"``, ``"k"`` or
        ``"v"``.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        # (shard_id, shard_offset, shard_size)
        layout = (
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        )
        packed_types = (PackedColumnParameter, PackedvLLMParameter)

        for shard_id, shard_offset, shard_size in layout:
            # Special case for Quantization.
            # If quantized, the offset and size are adjusted to account for
            # the packing.
            if (isinstance(param, packed_types)
                    and param.packed_dim == param.output_dim):
                shard_size, shard_offset = (
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset))

            shard = loaded_weight.narrow(param.output_dim, shard_offset,
                                         shard_size)
            self.weight_loader_v2(param, shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` into ``param`` through its QKV hook.

        With ``loaded_shard_id=None`` (special case for certain models) the
        weight is treated as already fused on disk and is either loaded whole
        (per-tensor scales, plain row/base parameters) or split by
        ``_load_fused_module_from_checkpoint``. Otherwise the id must be
        ``"q"``, ``"k"`` or ``"v"``, and the per-rank offset and size of that
        shard are passed to ``param.load_qkv_weight``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
            else:
                # TODO: @dsikka - move to parameter.py
                self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        offset = self._get_shard_offset_mapping(loaded_shard_id)
        size = self._get_shard_size_mapping(loaded_shard_id)
        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=offset,
                              shard_size=size)

992
993
994
995
    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Copy a checkpoint tensor into the matching slice of ``param``.

        Args:
            param: Destination parameter. Optional attributes read from it
                (``is_gguf_weight``, ``is_gguf_weight_type``, ``output_dim``,
                ``packed_dim``, ``is_metadata``, ``needs_scalar_to_array``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``ignore_warning``) select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``. ``None`` means the
                checkpoint stores the projection already fused: if ``param``
                has an ``output_dim`` the tensor is split into q/k/v and this
                method is called again for each piece, otherwise it is copied
                as a whole.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                # No shard id: the same type is recorded for q, k and v.
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once all three of q, k and v have been collected.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id for this slice.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # q is indexed by the rank itself; k and v by the rank divided
            # by the number of KV-head replicas.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
1178
1179


1180
class RowParallelLinear(LinearBase):
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
1203
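        reduce_results: If true, all-reduce the output across the
                        tensor-parallel ranks in forward (only when the
                        tensor-parallel world size is greater than 1).
        quant_config: Quantization configuration.
        fuse_gemm_rs: If true, reduce_results is forced to False and the
                      flag is forwarded to the base class.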
1204
1205
    """

1206
1207
1208
1209
1210
1211
1212
1213
    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        fuse_gemm_rs: bool = False,
    ):
        """Build the layer and let the quant method create its weights.

        The input dimension is split evenly across the tensor-parallel
        ranks; the output dimension is left whole on every rank.

        Raises:
            ValueError: if ``reduce_results`` is False while a bias is
                requested and ``skip_bias_add`` is False.
        """
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         fuse_gemm_rs=fuse_gemm_rs)

        self.input_is_parallel = input_is_parallel
        # fuse_gemm_rs overrides the caller's reduce_results choice.
        self.reduce_results = False if fuse_gemm_rs else reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        # Quant methods on the supported list get the v2 loader, all
        # others fall back to the legacy loader.
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name in WEIGHT_LOADER_V2_SUPPORTED:
            loader = self.weight_loader_v2
        else:
            loader = self.weight_loader

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=loader)

        bias_is_added = bias and not skip_bias_add
        if bias_is_added and not reduce_results:
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if not bias:
            self.register_parameter("bias", None)
            return

        # The bias is not partitioned: every rank holds the full vector.
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=params_dtype))
        bias_attrs = {
            "output_dim": 0,
            "weight_loader": self.weight_loader,
        }
        set_weight_attrs(self.bias, bias_attrs)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Behavior is driven by optional attributes on ``param``:

        * ``input_dim``: dimension along which the weight is split across
          tensor-parallel ranks. When absent, the tensor is copied whole.
        * ``is_sharded_weight`` / ``use_bitsandbytes_4bit``: the tensor is
          already the per-rank slice, so no narrowing is done.
        * ``is_gguf_weight_type``: the scalar is also stored on
          ``param.weight_type``.
        * ``is_gguf_weight``: an ``UninitializedParameter`` is materialized
          with the per-rank shape before copying.

        Args:
            param: destination parameter; written in place.
            loaded_weight: tensor read from the checkpoint.

        Raises:
            AssertionError: if the (possibly narrowed) tensor does not have
                the same shape as the parameter.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None rather than relying on truthiness:
            # input_dim == 0 is a valid shard dimension and must be divided
            # too, otherwise the narrowing below would see a full-size
            # parameter.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            # The parameter already has the per-rank size, so it defines
            # both the slice length and this rank's offset.
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"shape mismatch: parameter {tuple(param_data.shape)} vs "
            f"loaded weight {tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

1293
1294
    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Hand ``loaded_weight`` to the parameter's row-parallel loader.

        A shapeless (0-dim) tensor is first reshaped to shape ``(1,)``;
        everything else is passed to ``param.load_row_parallel_weight``
        unchanged.
        """
        # Scales stored on disk often have no shape (e.g. AutoFP8), so
        # give them one before delegating.
        if not loaded_weight.shape:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

1304
    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer to ``input_``.

        Returns:
            A pair ``(output, output_bias)``. ``output_bias`` is
            ``self.bias`` when ``skip_bias_add`` is set (the bias is then
            not applied here), otherwise ``None``.
        """
        if not self.input_is_parallel:
            # Input holds the full last dimension: keep only this rank's
            # partition.
            rank = get_tensor_model_parallel_rank()
            partitions = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = partitions[rank].contiguous()
        else:
            input_parallel = input_

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        if self.tp_rank > 0 or self.skip_bias_add:
            bias_ = None
        else:
            bias_ = self.bias
        output = self.quant_method.apply(self, input_parallel, bias=bias_)

        # With reduce_results unset or a single rank, the local result
        # is returned as is.
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output)

        return output, (self.bias if self.skip_bias_add else None)
1329
1330
1331
1332
1333
1334
1335

    def extra_repr(self) -> str:
        """Return the comma-separated layer summary used in the repr.

        Note that ``input_features`` is the per-partition input size,
        not the full input size.
        """
        fields = (
            f"input_features={self.input_size_per_partition}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
            f"tp_size={self.tp_size}",
            f"reduce_results={self.reduce_results}",
        )
        return ", ".join(fields)