linear.py 33.6 KB
Newer Older
1
from abc import abstractmethod
2
from typing import List, Optional
3
4
5
6
7

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
17
18
19
20
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger, obtained from vLLM's `init_logger` with this module's
# `__name__`; used below for the weight-loading warnings.
logger = init_logger(__name__)


21
22
23
24
25
26
27
28
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Object that may carry a `marlin_tile_size` attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        A `(shard_size, shard_offset)` tuple. Both values are multiplied by
        `param.marlin_tile_size` when that attribute exists and is not None;
        otherwise they are returned unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


29
class LinearMethodBase(QuantizeMethodBase):
    """Interface for linear methods, quantized or not.

    Subclasses provide `create_weights`, which attaches the weights to a
    layer, and `apply`, which applies those weights to an input tensor.
    """

    @abstractmethod
    def create_weights(
            self,
            layer: torch.nn.Module,
            input_size_per_partition: int,
            output_partition_sizes: List[int],
            input_size: int,
            output_size: int,
            params_dtype: torch.dtype,
            **extra_weight_attrs):
        """Create the weights of a linear layer as attributes of `layer`.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list holding
                the widths of Wq, Wk and Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights stored in `layer` to the input tensor `x`.

        `create_weights` is expected to have been called on `layer` first.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Plain (non-quantized) linear method.

    Args:
        separate_bias_add: If true, the bias is added in a separate step
                           after the matrix multiplication instead of being
                           passed to `F.linear`.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(
            self,
            layer: torch.nn.Module,
            input_size_per_partition: int,
            output_partition_sizes: List[int],
            input_size: int,
            output_size: int,
            params_dtype: torch.dtype,
            **extra_weight_attrs):
        """Register an uninitialized `weight` parameter on `layer`.

        The weight has shape
        `(sum(output_partition_sizes), input_size_per_partition)`, dtype
        `params_dtype` and `requires_grad=False`. It is tagged with
        `input_dim=1` and `output_dim=0`, then with `extra_weight_attrs`.
        """
        storage = torch.empty(sum(output_partition_sizes),
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        # The two attribute passes are kept separate and in this order.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear transform of `x` with `layer.weight`.

        When `separate_bias_add` is set, a non-None bias is added to the
        matmul result afterwards; otherwise the bias is handed to `F.linear`.
        """
        weight = layer.weight
        if not self.separate_bias_add:
            return F.linear(x, weight, bias)
        if bias is None:
            return F.linear(x, weight)
        return F.linear(x, weight) + bias


99
100
class LinearBase(torch.nn.Module):
    """Common base for the linear layers in this module.

    Stores the layer dimensions and dtype, and selects the quantization
    method used to create and apply the weights.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Falls back to
                      `torch.get_default_dtype()` when None.
        quant_config: Quantization configuration. When None, an
                      `UnquantizedLinearMethod` is used.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Remember the constructor arguments for the subclasses.
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        # Pick the method that will create and apply the weights.
        if quant_config is not None:
            self.quant_method: Optional[
                QuantizeMethodBase] = quant_config.get_quant_method(self)
        else:
            self.quant_method = UnquantizedLinearMethod()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Not implemented here; concrete layers override this."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is created on this rank (no sharding).

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # Every linear layer is expected to have a quant method.
        assert self.quant_method is not None
        # The per-partition sizes equal the full sizes for this layer.
        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if not bias:
            self.register_parameter("bias", None)
        else:
            bias_param = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            self.bias = bias_param
            set_weight_attrs(bias_param, {"output_dim": 0})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to `x`.

        Returns:
            A `(output, output_bias)` pair. With `skip_bias_add` the bias is
            not applied and is returned as `output_bias`; otherwise it is
            passed to the quant method and `output_bias` is None.
        """
        # NOTE: the annotation says Tensor, but a 2-tuple is returned.
        if self.skip_bias_add:
            applied_bias, returned_bias = None, self.bias
        else:
            applied_bias, returned_bias = self.bias, None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, applied_bias)
        return output, returned_bias

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        return (f"in_features={self.input_size}"
                f", output_features={self.output_size}"
                f", bias={self.bias is not None}")

186

187
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. NOTE: this argument is
                       accepted but not read by this class; the per-shard
                       sizes come from a `self.output_sizes` attribute when a
                       subclass has set one before calling this constructor.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        if hasattr(self, "output_sizes"):
            # QKV or MergedColumn: one partition size per logical weight.
            self.output_partition_sizes = [
                divide(logical_size, tp_size)
                for logical_size in self.output_sizes
            ]
        else:
            self.output_partition_sizes = [self.output_size_per_partition]

        # `self.weight_loader` is looked up on the instance, so a subclass
        # override is what gets attached to the created weights.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)

        if bias:
            # `self.params_dtype` is the resolved form of `params_dtype`
            # (default dtype when None), matching ReplicatedLinear.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of `loaded_weight` into `param`.

        If `param` has an `output_dim` attribute, `loaded_weight` is narrowed
        along that dim to the slice starting at `tp_rank * shard_size`, where
        `shard_size` is the parameter's own extent along that dim. Otherwise,
        if `param` has an `fp8_scales_shard_indexer`, that callable selects
        the source and destination (with `shard_id=0`). The shapes must match
        before the copy.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Apply the layer to `input_`.

        Returns:
            A `(output, output_bias)` pair. `output` is all-gathered across
            the tensor-parallel partitions when `gather_output` is set. With
            `skip_bias_add` the bias is not applied and is returned as
            `output_bias`; otherwise `output_bias` is None.
        """
        bias = None if self.skip_bias_add else self.bias

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer summary shown in the module's repr."""
        return (f"in_features={self.input_size}"
                f", output_features={self.output_size_per_partition}"
                f", bias={self.bias is not None}"
                f", tp_size={get_tensor_model_parallel_world_size()}"
                f", gather_output={self.gather_output}")

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        # Set before the parent constructor runs: the parent checks for this
        # attribute to derive the per-shard partition sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard, or an already-packed weight, into `param`.

        Args:
            param: Destination parameter. Optional attributes on it
                (`output_dim`, `packed_dim`, `pack_factor`, `is_metadata`,
                `shard_splitter`, `fp8_scales_shard_indexer`,
                `ignore_warning`) select the loading path.
            loaded_weight: Source tensor from the checkpoint.
            loaded_shard_id: Index into `self.output_sizes` of the logical
                weight being loaded. None means `loaded_weight` already
                holds all logical weights packed together; it is then split
                and each piece is loaded through a recursive call.

        Raises:
            NotImplementedError: If `param` defines a `shard_splitter`
                together with an `output_dim`, or together with
                `loaded_shard_id=None`.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        param_shard_splitter = getattr(param, "shard_splitter", None)
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        # A shard_splitter, when defined by the linear method, is applied
        # before the weight is copied and uses loaded_shard_id to place the
        # loaded data at the right location inside the parameter. It is
        # therefore incompatible with output_dim-based slicing and with
        # loading without a shard id.
        if param_shard_splitter is not None:
            if output_dim is not None:
                raise NotImplementedError(
                    "We do not currently support output_dim != None and "
                    "shard_splitter != None for a parameter. Please open an issue."
                )
            if loaded_shard_id is None:
                raise NotImplementedError(
                    "We do not currently support loaded_shard_id == None and "
                    "shard_splitter != None for a parameter. Please open an issue."
                )

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Lay the logical weights out back to back along the output dim.
            packed_shards = []
            running_offset = 0
            for index, width in enumerate(self.output_sizes):
                packed_shards.append((index, running_offset, width))
                running_offset += width

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in packed_shards:
                # If quantized, adjust the offset and size to account for
                # the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                piece = loaded_weight.narrow(output_dim, shard_offset,
                                             shard_size)
                self.weight_loader(param, piece, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, adjust the offset and size to account for the
            # packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            # Destination: this logical weight's slot inside the parameter.
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Source: this rank's slice of the loaded logical weight.
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            param_data = param_data.narrow(0, loaded_shard_id * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "MergedColumnParallelLinear, assume the weight is "
                "the same for all partitions.")

        if fp8_scales_shard_indexer is None:
            # Promote 0-dim tensors to shape (1,) before the shape check.
            if not param_data.shape:
                param_data = param_data.reshape(1)

            if not loaded_weight.shape:
                loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                   is None else total_num_kv_heads)

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size < self.total_num_kv_heads:
            # Enough KV heads to partition them across the ranks.
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # One KV head per rank; each KV head is shared by several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)

        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        q_width = self.num_heads * self.head_size * tp_size
        kv_width = self.num_kv_heads * self.head_size * tp_size
        # Set before the parent constructor runs: the parent checks for this
        # attribute to derive the per-shard partition sizes.
        self.output_sizes = [
            q_width,  # q_proj
            kv_width,  # k_proj
            kv_width,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load the q, k or v shard, or an already-packed QKV weight.

        Args:
            param: Destination parameter. Optional attributes on it
                (`output_dim`, `packed_dim`, `pack_factor`, `is_metadata`,
                `shard_splitter`, `fp8_scales_shard_indexer`,
                `ignore_warning`) select the loading path.
            loaded_weight: Source tensor from the checkpoint.
            loaded_shard_id: One of "q", "k" or "v". None means
                `loaded_weight` already holds q, k and v packed together; it
                is then split and each piece is loaded through a recursive
                call.

        Raises:
            NotImplementedError: If `param` defines a `shard_splitter`
                together with an `output_dim`, or together with
                `loaded_shard_id=None`.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        param_shard_splitter = getattr(param, "shard_splitter", None)
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        # A shard_splitter, when defined by the linear method, is applied
        # before the weight is copied and uses loaded_shard_id to place the
        # loaded data at the right location inside the parameter. It is
        # therefore incompatible with output_dim-based slicing and with
        # loading without a shard id.
        if param_shard_splitter is not None:
            if output_dim is not None:
                raise NotImplementedError(
                    "We do not currently support output_dim != None and "
                    "shard_splitter != None for a parameter. Please open an issue."
                )
            if loaded_shard_id is None:
                raise NotImplementedError(
                    "We do not currently support loaded_shard_id == None and "
                    "shard_splitter != None for a parameter. Please open an issue."
                )

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            q_total = self.total_num_heads * self.head_size
            kv_total = self.total_num_kv_heads * self.head_size
            packed_shards = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, q_total),
                ("k", q_total, kv_total),
                ("v", q_total + kv_total, kv_total),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in packed_shards:
                # If quantized, adjust the offset and size to account for
                # the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                piece = loaded_weight.narrow(output_dim, shard_offset,
                                             shard_size)
                self.weight_loader(param, piece, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            q_width = self.num_heads * self.head_size
            kv_width = self.num_kv_heads * self.head_size
            if loaded_shard_id == "q":
                shard_offset, shard_size = 0, q_width
            elif loaded_shard_id == "k":
                shard_offset, shard_size = q_width, kv_width
            elif loaded_shard_id == "v":
                shard_offset, shard_size = q_width + kv_width, kv_width
            # If quantized, adjust the offset and size to account for the
            # packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            # Destination: the q/k/v slot inside the parameter.
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Source: ranks that share a replicated KV head read the same
            # slice, hence the division by the replica count for k and v.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 shard_id * shard_size,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

        # Promote 0-dim tensors to shape (1,) before the shape check.
        if not param_data.shape:
            param_data = param_data.reshape(1)

        if not loaded_weight.shape:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


662
class RowParallelLinear(LinearBase):
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
685
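        reduce_results: If true, all-reduce the partial outputs across
                        tensor-parallel ranks in forward().
        quant_config: Quantization configuration.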
686
687
    """

688
689
690
691
692
693
694
695
696
    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None):
        """Set up the layer, create its weights and (optionally) its bias.

        Args:
            input_size: first dimension of matrix A.
            output_size: second dimension of matrix A.
            bias: If true, allocate a bias parameter of size ``output_size``.
            input_is_parallel: Stored on the layer; consulted by forward().
            skip_bias_add: Forwarded to the base class; when false together
                with ``bias`` the bias is added inside forward().
            params_dtype: Data type for the parameters.
            reduce_results: Stored on the layer; consulted by forward().
            quant_config: Forwarded to the base class.

        Raises:
            ValueError: If the bias would be added in forward() while
                ``reduce_results`` is false.
        """
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Each rank holds input_size / tp_size of the input dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)

        quant_method = self.quant_method
        assert quant_method is not None
        # The output dimension is a single logical partition of full width.
        quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)

        bias_added_in_forward = bias and not skip_bias_add
        if bias_added_in_forward and not reduce_results:
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if not bias:
            self.register_parameter("bias", None)
        else:
            # Bias has the full output width.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            bias_attrs = {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            }
            set_weight_attrs(self.bias, bias_attrs)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy the portion of ``loaded_weight`` owned by this rank into
        ``param``.

        * If ``param`` carries an ``input_dim`` attribute, the slice of
          ``loaded_weight`` starting at ``tp_rank * shard_size`` along that
          dimension is used, where ``shard_size`` is the parameter's own
          extent along ``input_dim``.
        * Otherwise, if ``param`` carries an ``fp8_scales_shard_indexer``, it
          is called with ``shard_id=0`` to pick both the destination and the
          source tensors.
        * Otherwise the loaded tensor is copied as a whole.

        Args:
            param: Destination parameter; its ``.data`` is written in place.
            loaded_weight: Tensor read from the checkpoint.
        """
        # Special case for Fp8 scales.
        scales_indexer = getattr(param, "fp8_scales_shard_indexer", None)

        rank = get_tensor_model_parallel_rank()
        shard_dim = getattr(param, "input_dim", None)
        target = param.data

        if shard_dim is not None:
            width = target.shape[shard_dim]
            loaded_weight = loaded_weight.narrow(shard_dim, rank * width,
                                                 width)
        elif scales_indexer is not None:
            # Special case for Fp8 scales.
            target, loaded_weight = scales_indexer(target,
                                                   loaded_weight,
                                                   shard_id=0)

        # Without an Fp8 indexer, a 0-dim tensor is promoted to one element.
        if scales_indexer is None and not loaded_weight.shape:
            loaded_weight = loaded_weight.reshape(1)

        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

    def forward(self, input_):
        """Apply the row-parallel linear transform to ``input_``.

        Args:
            input_: Input to the layer. When ``self.input_is_parallel`` is
                false it is split along its last dimension into
                ``self.tp_size`` partitions and only this rank's partition is
                used.

        Returns:
            A tuple ``(output, output_bias)``. When ``self.skip_bias_add`` is
            false the bias (if any) has been added to ``output`` and
            ``output_bias`` is ``None``; otherwise ``output`` carries no bias
            and ``output_bias`` is ``self.bias``.
        """
        # Pick this rank's share of the input.
        if self.input_is_parallel:
            local_input = input_
        else:
            rank = get_tensor_model_parallel_rank()
            chunks = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = chunks[rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        partial_output = self.quant_method.apply(self, local_input)

        # All-reduce only if requested and there is more than one rank.
        needs_all_reduce = self.reduce_results and self.tp_size > 1
        if needs_all_reduce:
            output_ = tensor_model_parallel_all_reduce(partial_output)
        else:
            output_ = partial_output

        if self.skip_bias_add:
            # Caller is responsible for applying the bias.
            return output_, self.bias
        if self.bias is None:
            return output_, None
        return output_ + self.bias, None
780
781
782
783
784
785
786
787

    def extra_repr(self) -> str:
        """Return a one-line summary of the layer's configuration.

        The summary lists the per-partition input size, the output size,
        whether a bias is present, the tensor-parallel size and the
        ``reduce_results`` flag.
        """
        return (f"input_features={self.input_size_per_partition}"
                f", output_features={self.output_size}"
                f", bias={self.bias is not None}"
                f", tp_size={self.tp_size}"
                f", reduce_results={self.reduce_results}")