# linear.py (34.8 KB)
1
from abc import abstractmethod
2
from typing import Dict, List, Optional, Tuple
3
4
5
6
7

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
17
18
19
20
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger; the QKV/merged-column weight loaders use it to warn
# when a weight without `output_dim` is loaded unsharded.
logger = init_logger(__name__)


21
22
23
24
25
26
27
28
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset by the param's Marlin tile size.

    Params that have no ``marlin_tile_size`` attribute (or have it set to
    None) get their size and offset back unchanged.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized ``(offset, size)`` stored under ``loaded_shard_id`` is
    rescaled by the ratio of ``param.data.shape[0]`` to the unquantized
    total held in ``qkv_offsets["total"][0]`` (integer floor division).

    Returns:
        ``(quantized_size, quantized_offset)`` -- size first, then offset.
    """
    unquantized_total = qkv_offsets["total"][0]
    offset, size = qkv_offsets[loaded_shard_id]
    rows = param.data.shape[0]
    return (size * rows // unquantized_total,
            offset * rows // unquantized_total)


44
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                weights (callers in this module pass a ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialized ``weight`` Parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, is
        allocated with ``torch.empty`` (contents are not initialized) and has
        ``requires_grad=False``. It is tagged with ``input_dim=1`` and
        ``output_dim=0``, followed by any ``extra_weight_attrs``.
        ``input_size`` and ``output_size`` are not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``F.linear(x, layer.weight, bias)``.

        With ``separate_bias_add`` the bias (if any) is added after the
        matmul instead of being passed into ``F.linear``.
        """
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)


114
115
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None,
                      ``torch.get_default_dtype()`` is used.
        quant_config: Quantization configuration. When None, the layer uses
                      ``UnquantizedLinearMethod``; otherwise the method comes
                      from ``quant_config.get_quant_method(self)``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Unsharded: the "partition" is the whole weight.
        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the linear layer.

        Returns a tuple ``(output, output_bias)``. ``output_bias`` is
        ``self.bias`` when ``skip_bias_add`` is set (the bias was not applied)
        and None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

201

202
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                      the list would be size 3. NOTE: this argument is
                      currently not read by this class; the per-shard sizes
                      come from ``self.output_sizes``, which subclasses set
                      before calling this constructor.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            # Use the dtype resolved by LinearBase (None -> torch default)
            # rather than the raw constructor argument.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's part of a checkpoint tensor into ``param``.

        If ``param`` has an ``output_dim`` attribute, ``loaded_weight`` is
        narrowed along that dim to the ``tp_rank``-th slice, whose width
        equals ``param``'s own size on that dim. Otherwise, if ``param``
        has an ``fp8_scales_shard_indexer``, that callable selects the
        destination/source pair. In every other case ``loaded_weight`` is
        copied as-is. The final shapes must match exactly (asserted).
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Run the column-parallel matmul.

        Returns a tuple ``(output, output_bias)``. ``output`` is all-gathered
        across tensor-parallel ranks when ``gather_output`` is set.
        ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set and
        None otherwise.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer. Each
                      entry must be divisible by the tensor-parallel size
                      (asserted).
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        # Set before super().__init__(): ColumnParallelLinear checks for
        # self.output_sizes to build per-shard output_partition_sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard (or the whole packed weight) into ``param``.

        With ``loaded_shard_id`` None, ``loaded_weight`` holds all shards
        concatenated along ``output_dim``. It is either copied directly (when
        ``param`` has no ``output_dim``) or split per entry of
        ``self.output_sizes`` and fed back through this method with each
        shard index.

        With an integer ``loaded_shard_id``, the destination slice of
        ``param`` is chosen by the first matching case: ``output_dim``
        (TP-sharded, adjusted for packing/Marlin/bitsandbytes),
        ``is_metadata`` (AQLM codebooks, concatenated along dim 0),
        ``shard_splitter``, ``fp8_scales_shard_indexer``, otherwise a full
        copy with a warning.

        Raises:
            NotImplementedError: if ``param`` defines ``shard_splitter``
                together with ``output_dim``, or when ``loaded_shard_id`` is
                None.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                # Shard width is taken from the loaded tensor itself rather
                # than from output_sizes.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = (loaded_weight.shape[output_dim] *
                                loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head and
            # each KV head is replicated across num_kv_head_replicas ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        # Set before super().__init__(): ColumnParallelLinear checks for
        # self.output_sizes to build per-shard output_partition_sizes.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load the q/k/v shard (or the whole packed QKV weight) into ``param``.

        With ``loaded_shard_id`` None, ``loaded_weight`` holds q, k and v
        concatenated along ``output_dim`` using the *total* head counts. It
        is either copied directly (when ``param`` has no ``output_dim``) or
        split into "q"/"k"/"v" and fed back through this method.

        With ``loaded_shard_id`` in ("q", "k", "v"), the destination slice of
        ``param`` is chosen by the first matching case: ``output_dim``
        (TP-sharded, adjusted for packing/Marlin/bitsandbytes; k/v source
        slices account for ``num_kv_head_replicas``), ``is_metadata`` (AQLM
        codebooks, concatenated along dim 0), ``shard_splitter``,
        ``fp8_scales_shard_indexer``, otherwise a full copy with a warning.

        Raises:
            NotImplementedError: if ``param`` defines ``shard_splitter``
                together with ``output_dim``, or when ``loaded_shard_id`` is
                None.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Replicated KV heads: ranks sharing a KV head read the same
            # source slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


686
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and tp_size > 1), all-reduce the per-rank
                        outputs so every rank returns the full result. If
                        false, each rank returns its local (partial) output;
                        adding bias inside the layer is then rejected.
        quant_config: Quantization config.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension, i.e. split
        # input_size evenly across the tensor-parallel ranks.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)

        # Without the all-reduce each rank only holds a partial output, so a
        # bias added here would be applied to every partial result.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is not parallelized: every rank holds the full
            # output_size vector. Use the resolved dtype so the bias matches
            # the dtype the weights were created with.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's portion of a checkpoint tensor into ``param``.

        - If ``param`` has an ``input_dim`` attribute, it is sharded along that
          dimension. Rank ``r`` takes the slice starting at
          ``r * shard_size``, where ``shard_size`` is ``param``'s size along
          ``input_dim``.
        - Otherwise, if ``param`` has an ``fp8_scales_shard_indexer``, that
          callable selects the data to copy (called with ``shard_id=0``).
        - Otherwise the full tensor is copied as-is. The bias is one example:
          it declares no ``input_dim``.

        Raises:
            AssertionError: if the selected data's shape does not match the
                destination's shape.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        # A 0-dim checkpoint tensor (presumably a per-tensor scalar; confirm
        # against the loaders) is reshaped to 1 element so the shape check
        # below can match a single-element param.
        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Run the row-parallel matmul, with an optional reduce and bias.

        If ``input_is_parallel`` is false, the full input is split along its
        last dimension and this rank's chunk is used.

        Returns:
            A tuple ``(output, output_bias)``. ``output_bias`` is
            ``self.bias`` when ``skip_bias_add`` is set, so the caller can
            fuse it. Otherwise the bias, if any, is already added and
            ``output_bias`` is ``None``.
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        # The bias is added after the all-reduce so it is applied exactly once.
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s