linear.py 50.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from abc import abstractmethod
5
from typing import Optional
6
7
8

import torch
import torch.nn.functional as F
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
19
# yapf: disable
20
from vllm.model_executor.parameter import (BasevLLMParameter,
21
                                           BlockQuantScaleParameter,
22
                                           PackedColumnParameter,
23
                                           PackedvLLMParameter,
24
25
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
26
# yapf: enable
27
28
29
30
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

31
# Names of quant-method classes whose layers are given ``weight_loader_v2``
# (which delegates to the parameter's own ``load_*`` methods) instead of the
# legacy ``weight_loader``. Matched against ``quant_method.__class__.__name__``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod"
]


def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    Params that carry no ``marlin_tile_size`` attribute are passed through
    untouched.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The first dimension of ``param.data`` is treated as the quantized total
    length, and the logical ``(offset, size)`` of ``loaded_shard_id`` are
    rescaled proportionally from the logical total to that length.

    Args:
        param: Parameter being loaded into; only ``param.data.shape[0]`` is
            read.
        shard_offsets: Mapping from shard id to ``(offset, size)`` in logical
            units. Must contain a ``"total"`` entry whose first element is
            the total logical size.
        loaded_shard_id: Key in ``shard_offsets`` of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before the floor division so integer scaling stays exact
    # whenever the ratio divides evenly.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-shard scale array for one shard.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix, so the
    param is an array of length N. ``loaded_weight`` is the value for a
    single on-disk shard; this returns the matching element of ``param``
    together with ``loaded_weight`` stripped of a leading size-1 dimension
    (if it has any dimensions at all), ready to be copied.

    Args:
        param: Array holding one entry per logical shard.
        loaded_weight: Value read from disk for one shard.
        shard_id: Integer index, or one of ``"q"``, ``"k"``, ``"v"``.

    Returns:
        ``(param[index], loaded_weight)``.

    Raises:
        KeyError: for a string shard id other than ``"q"``/``"k"``/``"v"``.
        ValueError: if ``shard_id`` is neither ``str`` nor ``int``.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are shapeless (0-d); compressed-tensors scales carry a
    # leading dimension of size 1 that has to be indexed away.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                parameters; the layers in this file pass ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a dense ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and is
        allocated uninitialized (``torch.empty``), presumably to be filled
        later by the weight loader passed in ``extra_weight_attrs``.
        ``input_size`` and ``output_size`` are not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Dim 0 is the (possibly fused) output dim, dim 1 the input dim; the
        # weight loaders read ``output_dim`` to decide which axis to shard.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``F.linear(x, layer.weight, bias)``."""
        return F.linear(x, layer.weight, bias)


class LinearBase(torch.nn.Module):
    """Base linear layer.

    Stores the common configuration and selects the quantization method;
    subclasses create the weights and implement ``forward``.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()``.
        quant_config: Quantization configuration. When ``None`` the layer
            uses ``UnquantizedLinearMethod``.
        prefix: The name of the layer in the state dict, including all
            parents (e.g. model.layers.0.qkv_proj). Forwarded to
            ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, output_bias)``; implemented by subclasses."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full (unsharded) weight and bias are created on every rank, and the
    weight loader copies the checkpoint tensor without slicing it.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` verbatim into ``param`` (no sharding)."""
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set (the bias was not added to ``output``)
            and ``None`` otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s


class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. NOTE: this argument is not
                       read by the constructor; per-shard sizes are taken from
                       ``self.output_sizes`` when a subclass sets it before
                       calling ``super().__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[list[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # Those subclasses assign ``self.output_sizes`` before calling this
        # constructor.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED get the v2
        # loader, which delegates to the parameter's own load method.
        use_v2_loader = (self.quant_method.__class__.__name__
                         in WEIGHT_LOADER_V2_SUPPORTED)
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(self.weight_loader_v2
                           if use_v2_loader else self.weight_loader))
        if bias:
            # Only this rank's slice of the bias is allocated.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        Unless the param is flagged as already sharded (``is_sharded_weight``
        or ``use_bitsandbytes_4bit``), the checkpoint tensor is narrowed along
        ``param.output_dim`` to the ``tp_rank``-th chunk whose length equals
        ``param.data.shape[output_dim]``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Read ``param.data`` only after a possible GGUF materialization.
        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the column-parallel matmul on this rank.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across TP
            ranks when ``gather_output`` is set, otherwise it is this rank's
            partition. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set, else ``None``.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: list[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before super().__init__: ColumnParallelLinear reads
        # ``self.output_sizes`` to compute per-shard partition sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard (or the whole fused tensor) into ``param``.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                sub-matrix ``loaded_weight`` belongs to, or ``None`` when the
                checkpoint already stores the fused tensor; in that case the
                tensor is split and this method recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                weight_type = loaded_weight.item()
                param.shard_weight_type = {
                    i: weight_type
                    for i in range(len(self.output_sizes))
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # NOTE(review): materializes once two shards are buffered,
                # which presumes a two-way (e.g. gate/up) merge; confirm.
                if len(param.data_container) == 2:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            packed_dim = getattr(param, "packed_dim", None)
            # Logical start offset of each shard in the fused output dim
            # (one extra trailing entry holding the total).
            starts = list(itertools.accumulate(self.output_sizes, initial=0))
            if use_bitsandbytes_4bit:
                # Loop-invariant: logical (offset, size) of every shard.
                orig_offsets = {
                    str(i): (starts[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            for shard_id, (shard_offset, shard_size) in enumerate(
                    zip(starts, self.output_sizes)):
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Position of this shard inside the rank-local fused param.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a shard via the parameter's ``load_merged_column_weight``.

        With ``loaded_shard_id=None`` the checkpoint tensor is treated as
        already fused and either loaded whole or split per shard.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact type check (not isinstance): subclasses of these fall
            # through to the fused-module splitter below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Block scales index blocks of ``block_n`` output rows, so the
            # offsets/sizes are ceil-divided by ``block_n``.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj).
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads

        # Partition the query heads across tensor-parallel ranks. KV heads are
        # partitioned as well when there are more of them than ranks;
        # otherwise every rank holds exactly one KV head and that head is
        # replicated across `num_kv_head_replicas` ranks.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

        input_size = self.hidden_size
        # Global output width computed from the per-rank head counts, so a
        # replicated KV head is counted once per rank that holds it.
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the start offset of a shard in this rank's output dim.

        ``"q"``, ``"k"`` and ``"v"`` map to the start of that slice, and
        ``"total"`` maps to the full per-rank width. Any other id gives
        ``None``.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the width of the q/k/v slice in this rank's output dim.

        Any id other than ``"q"``, ``"k"`` or ``"v"`` gives ``None``.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> list[tuple[str, int, int]]:
        """Return ``(shard_id, shard_offset, shard_size)`` for q, k and v.

        The offsets index the full, unpartitioned fused QKV tensor as stored
        on disk, i.e. they use the *total* head counts.
        """
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _get_bnb_qkv_offsets(
            self, num_heads: int,
            num_kv_heads: int) -> dict[str, tuple[int, int]]:
        """Build the ``{shard_id: (offset, size)}`` table for bitsandbytes.

        The table is passed to ``adjust_bitsandbytes_4bit_shard``. The
        ``"total"`` entry holds the combined width with a size of 0.

        Args:
            num_heads: Number of query heads the table should describe.
            num_kv_heads: Number of key/value heads the table should describe.
        """
        q_size = num_heads * self.head_size
        kv_size = num_kv_heads * self.head_size
        return {
            "q": (0, q_size),
            "k": (q_size, kv_size),
            "v": (q_size + kv_size, kv_size),
            "total": (q_size + 2 * kv_size, 0),
        }

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor through the parameter's own loader.

        Args:
            param: Destination parameter. Its ``load_qkv_weight`` method does
                the actual copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"``. ``None`` means the
                checkpoint stores QKV fused. Scale parameters and plain
                (exact type) Row/Base parameters are then loaded whole, and
                anything else is split per shard.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            # Exact type match on purpose: subclasses fall through to the
            # fused-split path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        # NOTE(review): unlike MergedColumnParallelLinear.weight_loader_v2,
        # BlockQuantScaleParameter offsets are not converted to block units
        # here; confirm block-quantized QKV scales load correctly.
        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # The KV replica count is passed as `num_heads`; presumably the
        # parameter uses it to pick the checkpoint block for replicated KV
        # heads -- confirm in parameter.py.
        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into ``param`` (legacy v1 loader path).

        Args:
            param: Destination parameter. Optional attributes read via
                ``getattr`` select the strategy: ``output_dim``,
                ``packed_dim``, ``is_metadata``, ``needs_scalar_to_array``,
                ``use_bitsandbytes_4bit``, ``is_sharded_weight``,
                ``is_gguf_weight``, ``is_gguf_weight_type`` and
                ``ignore_warning``.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` when the checkpoint
                stores the projections separately. ``None`` when they are
                fused on disk; the tensor is then split and this method
                recurses once per shard.

        Raises:
            AssertionError: If the shard id is not q/k/v, or if the final
                destination and source shapes differ.
        """
        # Special case for GGUF.
        # Initialize the GGUF param after we know the quantize type.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                # Collect this rank's slice of each projection; once q, k and
                # v are all present, build the nested weight.
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv),
            # e.g. Phi-3's qkv_proj.
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            packed_dim = getattr(param, "packed_dim", None)
            if use_bitsandbytes_4bit:
                # Loop-invariant: describes the whole on-disk fused tensor.
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.total_num_heads, self.total_num_kv_heads)

            for shard_id, shard_offset, shard_size in (
                    self._get_fused_shard_offsets()):
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Location of this shard inside the per-rank parameter.
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = self._get_bnb_qkv_offsets(
                    self.num_heads, self.num_kv_heads)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query blocks map one-to-one onto ranks. A replicated KV block is
            # shared by `num_kv_head_replicas` consecutive ranks.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the partial outputs across
                        tensor-parallel ranks (only when tp_size > 1).
                        Adding bias without skip_bias_add requires this.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all
                parents.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Split the input dimension (rows of A) across tensor-parallel ranks.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # Full-width bias on every rank; forward() adds it on rank 0 only.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param`` (v1 path).

        If ``param`` declares an ``input_dim`` and the checkpoint is not
        already sharded (``is_sharded_weight`` or bitsandbytes 4-bit), the
        tensor is narrowed along ``input_dim`` to the block owned by this
        rank. Otherwise it is copied as is.

        Raises:
            AssertionError: If the final shapes of ``param`` and the loaded
                tensor differ.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None explicitly. A truthiness test would skip
            # the split when input_dim == 0, leaving the parameter full-size.
            # The narrow below would then go out of range on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load through the parameter's ``load_row_parallel_weight``.

        A 0-dim tensor is first reshaped to shape ``(1,)``.
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Apply this rank's partition and optionally all-reduce the result.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set, otherwise ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s