linear.py 35.6 KB
Newer Older
1
from abc import abstractmethod
2
from typing import Dict, List, Optional, Tuple
3
4
5
6
7

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
from vllm.model_executor.utils import set_weight_attrs
gaoqiong's avatar
gaoqiong committed
17

zhuwenwen's avatar
zhuwenwen committed
18
import os
19
20
21
22

logger = init_logger(__name__)


23
24
25
26
27
28
29
30
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Parameters that have no ``marlin_tile_size`` attribute (or have it set to
    ``None``) are left untouched.

    Args:
        param: Parameter object that may carry a ``marlin_tile_size``
            attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        A ``(shard_size, shard_offset)`` tuple, each multiplied by the tile
        size when one is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Rescale a QKV shard's offset and size for a BitsAndBytes parameter.

    Args:
        param: Parameter whose ``data.shape[0]`` gives the quantized total.
        qkv_offsets: Maps each shard id to an ``(offset, size)`` pair. The
            ``"total"`` entry holds the unquantized total as its first
            element; its second element is ignored.
        loaded_shard_id: Key of the shard to look up in ``qkv_offsets``.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note that the size comes
        first, the reverse of the order used in ``qkv_offsets``. Both values
        are scaled by ``quantized_total / total`` with floor division.
    """
    unquantized_total, _unused = qkv_offsets["total"]
    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
    quantized_rows = param.data.shape[0]

    def _rescale(value):
        # Multiply before dividing so the floor is applied only once.
        return value * quantized_rows // unquantized_total

    return _rescale(shard_size), _rescale(shard_offset)


46
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list contains the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes supplied by the calling
                layer (the parallel layers in this file pass
                ``weight_loader``).

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: The layer holding the weights.
            x: The input tensor.
            bias: Optional bias tensor.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add
        # Opt-in alternative matmul path, enabled only when the environment
        # variable LLAMA_NN is exactly "1". Read once, at construction time.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an uninitialized weight and register it on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, does not
        require grad, and is registered under the name ``"weight"``.
        ``input_size`` and ``output_size`` are not used by this method.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Summed to give the first dimension.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype of the allocated weight.
            **extra_weight_attrs: Extra attributes set on the weight after
                ``input_dim`` / ``output_dim``.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by ``layer.weight`` and add ``bias`` if given.

        ``separate_bias_add`` takes precedence over the ``LLAMA_NN`` path.

        Args:
            layer: Module whose ``weight`` attribute is used.
            x: Input tensor.
            bias: Optional bias tensor.

        Returns:
            The result of the linear transformation.
        """
        weight = layer.weight

        if self.separate_bias_add:
            # Bias is added as a separate op instead of inside F.linear.
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)

        if self.use_llama_nn:
            # NOTE(review): this path multiplies by `weight` as-is, whereas
            # F.linear multiplies by its transpose. Presumably the weight is
            # stored pre-transposed when LLAMA_NN=1 -- confirm against the
            # weight-loading code.
            if bias is None:
                return torch.matmul(x, weight)
            if len(x.shape) == 2:
                # 2-D input: fused bias + matmul.
                return torch.addmm(bias, x, weight)
            return torch.matmul(x, weight) + bias

        return F.linear(x, weight, bias)
126
127


128
129
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Stores the common configuration and selects the quantization method.
    This base class creates no weights or bias; subclasses do that.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None, an
            ``UnquantizedLinearMethod`` is used.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Not implemented in the base class; subclasses must override.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # The weight is not partitioned: the full input and output sizes are
        # passed as the per-partition sizes.
        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(
            self,
            x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the layer to ``x``.

        Returns:
            ``(output, output_bias)``. When ``skip_bias_add`` is set the bias
            is not applied and is returned as ``output_bias``; otherwise it
            is passed to the quant method and ``output_bias`` is None.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer's sizes and bias flag as a string."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

215

216
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. This argument is accepted
                       but not read by this constructor: per-partition sizes
                       come from a ``self.output_sizes`` attribute that a
                       subclass sets before calling ``super().__init__``.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, tp_size)
                for output_size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            # self.params_dtype is the resolved dtype (the default dtype when
            # the argument was None), matching what the weight uses.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``output_dim`` attribute, ``loaded_weight`` is
        narrowed along that dim to the slice starting at
        ``tp_rank * shard_size``, where ``shard_size`` is the parameter's own
        extent along that dim. Otherwise, if the parameter carries an
        ``fp8_scales_shard_indexer`` callable, it is invoked with
        ``shard_id=0`` to pick the destination and source. With neither, the
        weight is copied whole.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor to load from.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Apply the layer to ``input_``.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across the
            tensor-parallel group when ``gather_output`` is set, otherwise it
            is this rank's partition. When ``skip_bias_add`` is set the bias
            is not applied and is returned as ``output_bias``; otherwise
            ``output_bias`` is None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer's per-partition sizes and settings as a string."""
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        # Must be set before super().__init__: the parent constructor reads
        # self.output_sizes to compute the per-partition output sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard, or an already packed weight, into ``param``.

        With ``loaded_shard_id=None`` the checkpoint tensor is treated as
        already packed: it is copied whole if the parameter has no
        ``output_dim``, otherwise it is split by ``self.output_sizes`` and
        this method is re-invoked once per shard.

        With an integer ``loaded_shard_id`` the destination region of the
        parameter and this rank's slice of ``loaded_weight`` are selected,
        then copied. The selection depends on which attributes the parameter
        carries, checked in this order: ``output_dim``, ``is_metadata``,
        ``shard_splitter``, ``fp8_scales_shard_indexer``.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor to load from.
            loaded_shard_id: Index into ``self.output_sizes``, or None for a
                packed weight.

        Raises:
            NotImplementedError: If the parameter defines ``shard_splitter``
                together with ``output_dim``, or ``shard_splitter`` is
                defined and ``loaded_shard_id`` is None.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Build (shard_id, offset, size) triples over the full
            # (unsharded) output sizes.
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id; that call does the
                # per-rank slicing and the copy.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Offset and size of this shard inside the per-rank parameter.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                # Size and offset are taken from the loaded tensor itself.
                # NOTE(review): the narrow of loaded_weight below still uses
                # tp_rank * shard_size, which only stays in range for
                # tp_rank == 0 -- confirm bitsandbytes is used with a single
                # tensor-parallel rank.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        # Promote 0-dim tensors to shape (1,) so the shape check and copy
        # below work. Skipped for parameters that carry an Fp8 scales
        # indexer.
        if fp8_scales_shard_indexer is None:
            if len(param_data.shape) == 0:
                param_data = param_data.reshape(1)

            if len(loaded_weight.shape) == 0:
                loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds one KV head
            # and each KV head is shared by several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        # Must be set before super().__init__: the parent constructor reads
        # self.output_sizes to compute the per-partition output sizes.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load the q, k or v shard, or an already packed QKV weight.

        With ``loaded_shard_id=None`` the checkpoint tensor is treated as
        already packed: it is copied whole if the parameter has no
        ``output_dim``, otherwise it is split into q/k/v using the total head
        counts and this method is re-invoked once per shard.

        With ``loaded_shard_id`` in ``{"q", "k", "v"}`` the destination
        region of the parameter and this rank's slice of ``loaded_weight``
        are selected, then copied. The selection depends on which attributes
        the parameter carries, checked in this order: ``output_dim``,
        ``is_metadata``, ``shard_splitter``, ``fp8_scales_shard_indexer``.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor to load from.
            loaded_shard_id: ``"q"``, ``"k"``, ``"v"``, or None for a packed
                weight.

        Raises:
            NotImplementedError: If the parameter defines ``shard_splitter``
                together with ``output_dim``, or ``shard_splitter`` is
                defined and ``loaded_shard_id`` is None.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id; that call does the
                # per-rank slicing and the copy.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offset and size of the shard inside the per-rank parameter,
            # which is laid out as [q | k | v].
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                # (offset, size) of each shard in unquantized units; the
                # "total" entry carries the overall size as its first element.
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are partitioned one slice per rank; ranks that
            # share a replicated KV head read the same K/V slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        # Promote 0-dim tensors to shape (1,) so the shape check and copy
        # below work.
        if len(param_data.shape) == 0:
            param_data = param_data.reshape(1)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
711
712


713
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and the tensor-parallel world size is
                        greater than 1), all-reduce the per-rank partial
                        outputs in `forward`. If false, each rank returns
                        its own partial output.
        quant_config: Quantization configuration.

    Raises:
        ValueError: if `reduce_results` is false while a bias is requested
                    and `skip_bias_add` is false, because the bias would
                    then be added to an un-reduced partial result.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        # Reject the invalid flag combination before any weights are
        # allocated, so a misconfigured layer fails cheaply.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Each tensor-parallel rank owns input_size / tp_size of the input
        # dimension; the output dimension is not split.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)

        if bias:
            # Use the same dtype the weights were created with
            # (self.params_dtype) rather than the raw, possibly-None argument.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy `loaded_weight` into `param`, taking this rank's shard.

        If `param` carries an `input_dim` attribute, only the slice of
        `loaded_weight` belonging to the current tensor-parallel rank along
        that dimension is copied. Otherwise, if `param` carries an
        `fp8_scales_shard_indexer`, that callable selects the destination
        and source tensors. Parameters with neither attribute (e.g. the
        bias) are copied whole.

        Raises:
            AssertionError: if the destination and source shapes differ
                            after sharding.
        """
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            # The local parameter already has the per-rank size, so its
            # extent along input_dim is the shard width.
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        # Promote a 0-dim checkpoint tensor to shape (1,). Skipped for
        # params with an Fp8 indexer, whose shapes are left untouched here.
        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape, (
            f"Shape mismatch while loading weight: parameter has shape "
            f"{tuple(param_data.shape)} but loaded weight has shape "
            f"{tuple(loaded_weight.shape)}")
        param_data.copy_(loaded_weight)

    def forward(
            self,
            input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the row-parallel linear transform.

        Returns:
            A tuple `(output, output_bias)`. When `skip_bias_add` is false,
            the bias (if any) is already added to `output` and
            `output_bias` is None. When `skip_bias_add` is true, `output`
            has no bias applied and `output_bias` is the layer's bias
            (which may itself be None).
        """
        # Obtain this rank's slice of the input along the last dimension,
        # unless the caller already supplies it pre-split.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_parallel)

        # Combine the per-rank partial results when requested.
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = (output_ + self.bias
                      if self.bias is not None else output_)
            output_bias = None
        else:
            # Hand the bias back so the caller can fuse the addition.
            output = output_
            output_bias = self.bias
        return output, output_bias

    def extra_repr(self) -> str:
        """Return the layer's configuration summary for `repr()`."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s