linear.py 58.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from abc import abstractmethod
5
from typing import Optional, List
6

zhuwenwen's avatar
zhuwenwen committed
7

8
9
import torch
import torch.nn.functional as F
10
from torch.nn.parameter import Parameter, UninitializedParameter
11

12
13
14
15
16
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
zhuwenwen's avatar
zhuwenwen committed
17

18
from vllm.logger import init_logger
19
20
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
21
# yapf: disable
22
from vllm.model_executor.parameter import (BasevLLMParameter,
23
                                           BlockQuantScaleParameter,
24
                                           PackedColumnParameter,
25
                                           PackedvLLMParameter,
26
27
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
28
# yapf: enable
29
from vllm.model_executor.utils import set_weight_attrs
gaoqiong's avatar
gaoqiong committed
30

zhuwenwen's avatar
zhuwenwen committed
31
import os
32
from vllm.model_executor.utils import gemm_bank_conf
zhuwenwen's avatar
zhuwenwen committed
33
34
35
36
37
38
import vllm.envs as envs

if envs.VLLM_USE_FLUX:
    # Optional dependency: `flux` and `get_tp_group` are only bound when
    # VLLM_USE_FLUX is set. GemmRS.create_weights / AGCook.create_weights
    # below call flux.GemmRS / flux.AGKernel and get_tp_group(), so those
    # fused methods can only create weights when this flag is enabled.
    import flux
    from vllm.distributed.parallel_state import get_tp_group

logger = init_logger(__name__)

42
# Class names of linear methods whose parameters are loaded through the
# layers' ``weight_loader_v2``. Layers compare against
# ``quant_method.__class__.__name__``; any method not listed here (e.g.
# UnquantizedLinearMethod, GemmRS, AGCook) uses the legacy ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod", "BlockInt8LinearMethod",
]
50

51

52
53
54
55
56
57
58
59
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    Parameters that have no ``marlin_tile_size`` attribute (or have it set
    to ``None``) get their inputs back unchanged.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by the tile size when
        one is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


60
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized shard layout in ``shard_offsets`` is rescaled in
    proportion to the first dimension of the quantized ``param``.

    Args:
        param: The quantized parameter; only ``param.data.shape[0]`` is read.
        shard_offsets: Maps a shard id to ``(offset, size)`` in the
            unquantized layout. The first element of the ``"total"`` entry
            is the total unquantized size.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note that size comes first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Proportional integer rescale: value * quantized_total / total, floored.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused scale array that a loaded scale belongs to.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix in an
    array ``param`` of length N, while each on-disk ``loaded_weight`` holds
    the scale of a single shard. ``shard_id`` picks the slot: either an
    integer index or one of ``"q"``, ``"k"``, ``"v"`` (mapped to 0, 1, 2).

    Returns:
        ``(param[index], scale)``, where ``scale`` is ``loaded_weight`` with
        its leading size-1 dimension removed if it has any dimensions.

    Raises:
        ValueError: if ``shard_id`` is neither a ``str`` nor an ``int``.
        KeyError: if ``shard_id`` is a string other than "q"/"k"/"v".
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    scale = loaded_weight
    # AutoFP8 scales are shapeless (0-d); compressed-tensors scales carry a
    # leading dimension of size 1 that has to be unwrapped.
    if len(scale.shape) > 0:
        assert scale.shape[0] == 1
        scale = scale[0]

    return param[index], scale


98
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes (callers pass e.g.
                ``weight_loader``) to attach to the created weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two environment variables are read once, at construction time:

    * ``LLAMA_NN=1``: ``apply`` computes ``x @ layer.weight`` with
      ``torch.matmul`` / ``torch.addmm`` instead of ``F.linear``.
      NOTE(review): ``create_weights`` allocates the weight as
      ``(output, input)``, so this path presumably relies on the weight being
      transposed after loading elsewhere -- confirm.
    * ``GEMM_PAD=1``: together with ``LLAMA_NN``, lets ``apply`` drop the
      last 32 columns of the weight when
      ``gemm_bank_conf(weight.shape[1] - 32)`` is truthy.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an empty ``(sum(output_partition_sizes), input)`` weight."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the (optionally biased) linear projection of ``x``."""
        if not self.use_llama_nn:
            return F.linear(x, layer.weight, bias)

        # Use the flag cached in __init__: the previous direct lookup
        # os.environ['GEMM_PAD'] raised KeyError whenever GEMM_PAD was unset.
        if self.use_gemm_pad and gemm_bank_conf(layer.weight.shape[1] - 32):
            # NOTE(review): this runs on every call and reassigns
            # layer.weight with a plain tensor slice; presumably the weight
            # is no longer a registered Parameter here and gemm_bank_conf is
            # falsy for the trimmed width -- confirm.
            layer.weight = layer.weight[:, :-32]

        if bias is None:
            return torch.matmul(x, layer.weight)
        if len(x.shape) == 2:
            # Fused bias + GEMM for 2-D inputs.
            return torch.addmm(bias, x, layer.weight)
        return torch.matmul(x, layer.weight) + bias
170

171

172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
class GemmRS(LinearMethodBase):
    """Fused GEMM + reduce-scatter (``flux.GemmRS``) without quantization.

    ``create_weights`` needs ``flux`` and ``get_tp_group``, which this module
    only imports when ``envs.VLLM_USE_FLUX`` is set.
    """

    def __init__(self, separate_bias_add: bool = False):
        # NOTE(review): stored for API compatibility; not read in this class.
        self.separate_bias_add = separate_bias_add
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the weight and build the flux GemmRS operator."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.gemm_rs_op = flux.GemmRS(
            get_tp_group().device_group,
            1,  # One node
            8192,  # Max M. TODO: Pass in correctly.
            output_size,  # N
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            params_dtype,
            # Note: transpose_weight=False means that B is transposed.
            # Only the LLAMA_NN path passes True (presumably because its
            # weight is stored in the other layout -- confirm).
            transpose_weight=self.use_llama_nn,
            # Note: bfloat16 requires fuse_reduction=False.
            fuse_reduction=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused GEMM + reduce-scatter; a bias cannot be fused."""
        assert bias is None, "GemmRS does not support a fused bias"

        output = self.gemm_rs_op.forward(x, layer.weight)
        # squeeze(0) drops a leading size-1 dim and is a no-op otherwise.
        return output.squeeze(0)


class AGCook(LinearMethodBase):
    """Fused all-gather + GEMM (``flux.AGKernel``) without quantization.

    ``create_weights`` needs ``flux`` and ``get_tp_group``, which this module
    only imports when ``envs.VLLM_USE_FLUX`` is set.
    """

    def __init__(self, separate_bias_add: bool = False):
        # NOTE(review): stored for API compatibility; not read in this class.
        self.separate_bias_add = separate_bias_add
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the weight and build the flux all-gather GEMM operator."""
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.ag_gemm_op = flux.AGKernel(
            get_tp_group().device_group,
            1,  # One node
            8192,  # Max M. TODO: Pass in correctly.
            weight.shape[0],  # N
            weight.shape[1],  # K
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max m.
            params_dtype,
            params_dtype,
            # Note: transpose_weight=False means that B is transposed.
            # Only the LLAMA_NN path passes True (presumably because its
            # weight is stored in the other layout -- confirm).
            transpose_weight=self.use_llama_nn,
            # Note: if local_copy=True, I hit the following runtime error:
            # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
            #   Check failed: 33554432((input.numel() * input.element_size()))
            #                 == 139836453421056((this->chunk_size))
            local_copy=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused all-gather + GEMM; a bias cannot be fused."""
        assert bias is None, "AGCook does not support a fused bias"

        return self.ag_gemm_op.forward(x, layer.weight)
    
310
311
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj). Forwarded to
            ``quant_config.get_quant_method``.
        fuse_gemm_rs: If true, use the fused GEMM + reduce-scatter method
            (``GemmRS``). Requires ``quant_config`` to be None.
        fuse_ag_gemm: If true, use the fused all-gather + GEMM method
            (``AGCook``). Requires ``quant_config`` to be None. Ignored when
            ``fuse_gemm_rs`` is also set.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        fuse_gemm_rs: bool = False,
        fuse_ag_gemm: bool = False,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Pick the linear method. The fused flux paths only exist for
        # unquantized weights; fuse_gemm_rs takes precedence.
        self.quant_method: Optional[QuantizeMethodBase]
        if fuse_gemm_rs:
            assert quant_config is None
            self.quant_method = GemmRS()
        elif fuse_ag_gemm:
            assert quant_config is None
            self.quant_method = AGCook()
        elif quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, output_bias)``; implemented by subclasses."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full, unsharded weight.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` unchanged (no sharding)."""
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"into a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, bias)``; bias is returned only if skipped."""
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

430

431
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this constructor:
                       the per-partition split comes from ``self.output_sizes``
                       when a subclass sets that attribute before calling
                       ``super().__init__``.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        fuse_ag_gemm: If true, use the fused all-gather + GEMM linear method
                      (requires ``quant_config`` to be None).
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[list[int]] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix, fuse_ag_gemm=fuse_ag_gemm)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        The full tensor is narrowed along ``param.output_dim`` to rank
        ``tp_rank``'s shard unless the parameter is already sharded
        (``is_sharded_weight`` or bitsandbytes 4-bit).
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Tried to load weights of shape {tuple(loaded_weight.shape)} "
            f"into a parameter of shape {tuple(param_data.shape)}")
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, bias)``; bias is returned only if skipped."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.mlp.gate_up_proj)
        fuse_ag_gemm: Forwarded unchanged to ColumnParallelLinear.
                      NOTE(review): presumably selects the fused
                      all-gather + GEMM path (see the VLLM_USE_FLUX import at
                      the top of the file); confirm in ColumnParallelLinear.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: list[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        # Set before super().__init__ so the base constructor can see it.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        # Every merged sub-projection must split evenly across TP ranks.
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         fuse_ag_gemm=fuse_ag_gemm)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one sub-projection (or an already-fused tensor) into ``param``.

        Args:
            param: Destination parameter. Optional attributes on it
                (``is_gguf_weight_type``, ``is_gguf_weight``, ``output_dim``,
                ``packed_dim``, ``is_metadata``, ``needs_scalar_to_array``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``ignore_warning``) select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-projection being loaded, or ``None`` when the checkpoint
                stores all sub-projections already fused. In the fused case,
                the tensor is split and this method recurses once per
                sub-projection.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once every sub-projection has been collected.
                # Comparing against len(self.output_sizes) (not a fixed 2)
                # keeps this correct for layers merging more than two
                # projections.
                if len(param.data_container) == len(self.output_sizes):
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)

            # Unpacked (offset, size) of each sub-projection, keyed by the
            # stringified shard index. It does not depend on the loop
            # variable, so it is built once rather than per shard.
            orig_offsets: dict[str, tuple[int, int]] = {}
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Location of this sub-projection inside the per-rank param.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a sub-projection through the vLLM parameter's own loader.

        Args:
            param: Destination vLLM parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes``, or ``None``
                when the checkpoint tensor is already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            from vllm.model_executor.layers.quantization.blockwise_int8 import (
                BlockInt8LinearMethod, BlockInt8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod,
                               BlockInt8LinearMethod, BlockInt8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block-scale params hold one entry per block_n output rows, so
            # offsets and sizes are measured in (ceil-divided) blocks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        fuse_ag_gemm: Forwarded unchanged to ColumnParallelLinear.
                      NOTE(review): presumably selects the fused
                      all-gather + GEMM path (see the VLLM_USE_FLUX import at
                      the top of the file); confirm in ColumnParallelLinear.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 fuse_ag_gemm: bool = False):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # No more KV heads than ranks: each rank holds a single KV head,
            # and groups of num_kv_head_replicas ranks share the same one.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        # Set before super().__init__ so the base constructor can see it.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         fuse_ag_gemm=fuse_ag_gemm)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Offset of shard ``q``/``k``/``v`` (or the ``total`` width) within
        this rank's QKV output, in unpacked elements. Unknown ids give None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Size of shard ``q``/``k``/``v`` within this rank's QKV output, in
        unpacked elements. Unknown ids give None.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, offset, size)`` for q, k and v inside an
        unsharded QKV tensor that is already fused on disk (all heads, in
        unpacked elements along the output dimension)."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a q/k/v shard through the vLLM parameter's own loader.

        Args:
            param: Destination vLLM parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or ``None``
                when the checkpoint tensor is already fused.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a q/k/v shard (or an already-fused QKV tensor) into ``param``.

        Args:
            param: Destination parameter. Optional attributes on it
                (``is_gguf_weight_type``, ``is_gguf_weight``, ``output_dim``,
                ``packed_dim``, ``is_metadata``, ``needs_scalar_to_array``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``ignore_warning``) select the loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of ``"q"``, ``"k"``, ``"v"``, or ``None``
                when the checkpoint stores QKV already fused. In the fused
                case, the tensor is split and this method recurses once per
                shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once q, k and v have all been collected.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = self._get_fused_shard_offsets()

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Unpacked (offset, size) of q/k/v in the fused tensor. It does
            # not depend on the loop variable, so it is built once.
            orig_qkv_offsets: dict[str, tuple[int, int]] = {}
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    sid: (offset, size)
                    for sid, offset, size in shard_offsets
                }
                orig_qkv_offsets["total"] = (
                    (self.total_num_heads + 2 * self.total_num_kv_heads) *
                    self.head_size, 0)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Location of this shard inside the per-rank QKV param; same
            # values weight_loader_v2 uses.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    sid: (self._get_shard_offset_mapping(sid),
                          self._get_shard_size_mapping(sid))
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Ranks replicating the same KV head read the same K/V slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, forward() all-reduces the partial outputs
                        across the tensor-parallel group. When false, a bias
                        may only be used together with skip_bias_add.
        quant_config: Quantization configuration.
        prefix: Name prefix, passed through to LinearBase.
        fuse_gemm_rs: Passed through to LinearBase; when true, the
                      all-reduce in forward() is disabled.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 fuse_gemm_rs: bool = False):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix, fuse_gemm_rs=fuse_gemm_rs)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        # fuse_gemm_rs turns off the all-reduce in forward().
        # NOTE(review): presumably the reduction is done by a fused
        # GEMM + reduce-scatter path outside this class — confirm.
        if fuse_gemm_rs:
            self.reduce_results = False

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            # Quant methods named in WEIGHT_LOADER_V2_SUPPORTED get the
            # BasevLLMParameter-based loader; all others get the legacy one.
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        # This check uses the constructor argument, not self.reduce_results,
        # so disabling the reduction via fuse_gemm_rs does not trigger it.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is full-size (not sharded) on every rank.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Legacy loader: copy this rank's slice of a checkpoint tensor.

        If ``param`` carries an ``input_dim`` attribute and the checkpoint
        tensor is not already sharded, ``loaded_weight`` is narrowed along
        ``input_dim`` to the slice owned by this tensor-parallel rank.
        Parameters without ``input_dim`` (such as the bias, which is
        registered with ``output_dim`` only) are copied whole.

        Args:
            param: Destination parameter (possibly a GGUF
                ``UninitializedParameter`` that is materialized here).
            loaded_weight: Tensor read from the checkpoint.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly: input_dim == 0 is a valid
            # dimension and must still be divided across ranks, otherwise
            # the narrow below would read past the end for tp_rank > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Loader for BasevLLMParameter-based quant methods.

        Sharding is delegated to ``param.load_row_parallel_weight``; this
        method only promotes 0-dim scalars to shape ``(1,)`` first.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the row-parallel linear layer.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set (the caller must add it), else ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Take this rank's slice of the last dimension.
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s