linear.py 29.7 KB
Newer Older
1
from abc import abstractmethod
2
from typing import List, Optional
3
4
5
6
7

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
17
18
19
20
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


21
22
23
24
25
26
27
28
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset by the param's Marlin tile size.

    If ``param`` has a non-None ``marlin_tile_size`` attribute, both values
    are multiplied by it; otherwise they are returned unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


29
class LinearMethodBase(QuantizeMethodBase):
    """Interface for the (possibly quantized) linear methods.

    A linear method decides how a linear layer's weights are allocated
    (``create_weights``) and how they are applied to an input (``apply``).
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the weights of a linear layer.

        Implementations are expected to attach the created weights to
        ``layer`` as attributes.

        Args:
            layer: The layer that owns this linear method.
            input_size_per_partition: Width of the weight's input dimension
                on the current rank.
            output_partition_sizes: Width of the output dimension of every
                logical weight on the current rank. For a fused QKV
                projection, for example, this lists the widths of Wq, Wk
                and Wv.
            input_size: Input dimension of the weight across all ranks.
            output_size: Output dimension of the weight across all ranks.
            params_dtype: Data type of the parameters.
            **extra_weight_attrs: Additional attributes to attach to the
                created weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer's weights on ``x`` (optionally adding ``bias``).

        ``create_weights`` must already have been called on ``layer``.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Plain, unquantized linear method backed by ``F.linear``.

    Args:
        separate_bias_add: If true, the bias is added in a separate step
                           after the matrix multiplication instead of being
                           passed to ``F.linear``.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # All logical output weights are stacked along dim 0 of one tensor.
        rows = sum(output_partition_sizes)
        weight = Parameter(torch.empty(rows,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Tensor layout is (output, input).
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if not self.separate_bias_add:
            return F.linear(x, layer.weight, bias)
        # Bias (if any) is added after the matmul rather than fused into it.
        out = F.linear(x, layer.weight)
        if bias is None:
            return out
        return out + bias


101
102
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Stores the common configuration and selects the quant method. Bias
    creation is left to subclasses.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None, the value of
                      ``torch.get_default_dtype()`` is used.
        quant_config: Quantization config. When None, the layer uses
                      ``UnquantizedLinearMethod``; otherwise the method is
                      obtained from ``quant_config.get_quant_method(self)``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Must be overridden by subclasses.
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is kept on every rank (no sharding).

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, the layer owns a bias parameter.
        skip_bias_add: If true, the bias is not added in ``forward`` but is
                       returned alongside the output instead.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # Every linear layer is expected to have a quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if not bias:
            self.register_parameter("bias", None)
            return
        self.bias = Parameter(
            torch.empty(self.output_size, dtype=self.params_dtype))
        set_weight_attrs(self.bias, {"output_dim": 0})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns a pair (output, output_bias). output_bias is self.bias
        # only when skip_bias_add is set; otherwise it is None.
        assert self.quant_method is not None
        if self.skip_bias_add:
            return self.quant_method.apply(self, x, None), self.bias
        return self.quant_method.apply(self, x, self.bias), None

    def extra_repr(self) -> str:
        return ", ".join([
            f"in_features={self.input_size}",
            f"output_features={self.output_size}",
            f"bias={self.bias is not None}",
        ])

190

191
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, e.g. for a
                      fused QKV projection the list has 3 entries. Defaults
                      to ``[output_size]``. Each entry must be divisible by
                      the tensor-parallel world size.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[List[int]] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        if output_sizes is None:
            output_sizes = [output_size]
        # Use the same divisibility-checking divide() as above instead of
        # floor division. A logical output that cannot be split evenly
        # across ranks then fails here instead of silently yielding a
        # mis-sized weight.
        output_partition_sizes = [
            divide(size, tp_size) for size in output_sizes
        ]
        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         output_partition_sizes,
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        if bias:
            # Use the resolved dtype (never None), consistent with the
            # weights created above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Params with an ``output_dim`` attribute receive the contiguous slice
        at ``tp_rank * shard_size`` along that dim. Fp8 scale params delegate
        slicing to their ``fp8_scales_shard_indexer``. Anything else is
        copied as-is.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output`` is all-gathered across ranks when ``gather_output`` is
        set. ``output_bias`` is the bias only when ``skip_bias_add`` is set;
        otherwise it is None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel linear layers fused into one weight.

    Works like ``ColumnParallelLinear``, except that the weight is the
    concatenation of several logical weights along the output dimension.
    When weights are loaded, each logical weight is sharded on its own.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: output dimension of each fused logical layer.
        bias: If true, add bias.
        gather_output: If true, all-gather the output so every GPU holds the
                       full result; otherwise each GPU keeps its own slice.
        skip_bias_add: If true, the bias is not added in ``forward`` but is
                       returned instead, so it can be fused with a later
                       element-wise operation.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        # Each logical weight must split evenly across the TP ranks.
        assert all(size % tp_size == 0 for size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         output_sizes=self.output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` (or one logical piece of it) into ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Checkpoint tensor. If ``loaded_shard_id`` is None,
                it holds all logical weights already concatenated along the
                output dim. The concatenated tensor is split and this method
                recurses once per piece.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                weight contained in ``loaded_weight``, or None.
        """
        dst = param.data
        output_dim = getattr(param, "output_dim", None)
        # AQLM codebooks are flagged as metadata.
        is_metadata = getattr(param, "is_metadata", False)
        # Fp8 scales bring their own shard indexer.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # The checkpoint already holds every logical weight, packed.
            if output_dim is None:
                assert dst.shape == loaded_weight.shape
                dst.copy_(loaded_weight)
                return
            packed_dim = getattr(param, "packed_dim", None)
            running = 0
            for piece, size in enumerate(self.output_sizes):
                offset, length = running, size
                running += size
                if packed_dim == output_dim:
                    # Quantized weights are packed along the output dim, so
                    # shrink the bounds by the pack factor and then apply
                    # the Marlin tile scaling (if any).
                    length = length // param.pack_factor
                    offset = offset // param.pack_factor
                    length, offset = adjust_marlin_shard(
                        param, length, offset)
                self.weight_loader(
                    param, loaded_weight.narrow(output_dim, offset, length),
                    piece)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Location of this logical weight inside the local partition.
            offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            length = self.output_sizes[loaded_shard_id] // tp_size
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                length = length // param.pack_factor
                offset = offset // param.pack_factor
                length, offset = adjust_marlin_shard(param, length, offset)
            dst = dst.narrow(output_dim, offset, length)
            # Take this rank's slice of the logical weight.
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * length, length)
        elif is_metadata:
            # Metadata is one fixed-size block per logical weight, stacked
            # along dim 0.
            length = loaded_weight.shape[0]
            dst = dst.narrow(0, loaded_shard_id * length, length)
        elif fp8_scales_shard_indexer is not None:
            dst, loaded_weight = fp8_scales_shard_indexer(
                dst, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "MergedColumnParallelLinear, assume the weight is "
                "the same for all partitions.")

        assert dst.shape == loaded_weight.shape
        dst.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):
    """Fused query/key/value projection for attention.

    The Q, K and V weights are concatenated along the output dimension and
    the layer is parallelized along the head dimension. When there are no
    more key/value heads than tensor-parallel ranks (e.g. multi-query or
    grouped-query attention), each rank gets one KV head and KV heads are
    replicated across ranks, while query heads are always partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads.
                            Defaults to ``total_num_heads`` when None.
        bias: If true, add bias.
        skip_bias_add: If true, the bias is not added in ``forward`` but is
                       returned instead, so it can be fused with a later
                       element-wise operation.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads

        tp_size = get_tensor_model_parallel_world_size()
        # Query heads are always split evenly across ranks.
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # More KV heads than ranks: split them like the query heads.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # No more KV heads than ranks: one KV head per rank, each KV
            # head replicated num_kv_head_replicas times.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)

        # Global (all-rank) widths of the Q, K and V sub-matrices.
        q_width = self.num_heads * tp_size * self.head_size
        kv_width = self.num_kv_heads * tp_size * self.head_size
        super().__init__(input_size=self.hidden_size,
                         output_size=(self.num_heads + 2 * self.num_kv_heads) *
                         tp_size * self.head_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         output_sizes=[q_width, kv_width, kv_width])

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load ``loaded_weight`` (or its "q"/"k"/"v" piece) into ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Checkpoint tensor. If ``loaded_shard_id`` is None,
                it holds Q, K and V already concatenated along the output
                dim. The concatenated tensor is split and this method
                recurses once per piece.
            loaded_shard_id: One of "q", "k", "v", or None.
        """
        dst = param.data
        output_dim = getattr(param, "output_dim", None)
        # AQLM codebooks are flagged as metadata.
        is_metadata = getattr(param, "is_metadata", False)
        # Fp8 scales bring their own shard indexer.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # The checkpoint already holds Q, K and V packed together.
            if output_dim is None:
                assert dst.shape == loaded_weight.shape
                dst.copy_(loaded_weight)
                return
            q_total = self.total_num_heads * self.head_size
            kv_total = self.total_num_kv_heads * self.head_size
            layout = (
                # (shard_id, shard_offset, shard_size) in the full weight
                ("q", 0, q_total),
                ("k", q_total, kv_total),
                ("v", q_total + kv_total, kv_total),
            )
            packed_dim = getattr(param, "packed_dim", None)
            for piece, offset, length in layout:
                if packed_dim == output_dim:
                    # Quantized weights are packed along the output dim, so
                    # shrink the bounds by the pack factor and then apply
                    # the Marlin tile scaling (if any).
                    length = length // param.pack_factor
                    offset = offset // param.pack_factor
                    length, offset = adjust_marlin_shard(
                        param, length, offset)
                self.weight_loader(
                    param, loaded_weight.narrow(output_dim, offset, length),
                    piece)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            # Local (per-rank) widths of the Q and K/V sub-matrices.
            q_local = self.num_heads * self.head_size
            kv_local = self.num_kv_heads * self.head_size
            if loaded_shard_id == "q":
                offset, length = 0, q_local
            elif loaded_shard_id == "k":
                offset, length = q_local, kv_local
            elif loaded_shard_id == "v":
                offset, length = q_local + kv_local, kv_local
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                length = length // param.pack_factor
                offset = offset // param.pack_factor
                length, offset = adjust_marlin_shard(param, length, offset)
            dst = dst.narrow(output_dim, offset, length)
            # Query slices map one-to-one onto ranks; a KV slice is shared by
            # num_kv_head_replicas consecutive ranks.
            if loaded_shard_id == "q":
                source_slot = tp_rank
            else:
                source_slot = tp_rank // self.num_kv_head_replicas
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 source_slot * length, length)
        elif is_metadata:
            # Metadata is one fixed-size block per q/k/v, stacked along dim 0.
            length = loaded_weight.shape[0]
            slot = ["q", "k", "v"].index(loaded_shard_id)
            dst = dst.narrow(0, slot * length, length)
        elif fp8_scales_shard_indexer is not None:
            dst, loaded_weight = fp8_scales_shard_indexer(
                dst, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

        assert dst.shape == loaded_weight.shape
        dst.copy_(loaded_weight)


585
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the partial outputs across ranks
                        (only when there is more than one rank).
        quant_config: Quantization config.

    Raises:
        ValueError: if ``reduce_results`` is False while a bias would be
            added to the (unreduced) output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Reject the invalid configuration before the quant method allocates
        # any weights.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        # Split the input dimension (the weight's last dim) across ranks.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size_per_partition,
                                         [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            # Bias is not parallelized. Use the resolved dtype (never None),
            # consistent with the weights created above.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Params with an ``input_dim`` attribute receive the contiguous slice
        at ``tp_rank * shard_size`` along that dim. Fp8 scale params delegate
        slicing to their ``fp8_scales_shard_indexer``. Anything else (e.g.
        the bias) is copied as-is.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        The bias is added after the (optional) all-reduce, unless
        ``skip_bias_add`` is set, in which case it is returned as
        ``output_bias``.
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s