# linear.py
1
from abc import abstractmethod
2
from typing import Dict, List, Optional, Tuple
3
4
5
6
7

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
17
18
19
20
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger; the fused weight loaders below use it to warn when a
# parameter has no `output_dim` attribute.
logger = init_logger(__name__)


21
22
23
24
25
26
27
28
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Object that may carry a ``marlin_tile_size`` attribute.
        shard_size: Shard size before tile scaling.
        shard_offset: Shard offset before tile scaling.

    Returns:
        ``(shard_size, shard_offset)``. Both values are multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise they are returned unchanged.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Map an unquantized shard's offset and size onto a BitsAndBytes param.

    The unquantized values are rescaled by the ratio between the first
    dimension of ``param.data`` and the unquantized total, using integer
    (floor) division.

    Args:
        param: Quantized parameter; only ``param.data.shape[0]`` is read.
        qkv_offsets: Maps a shard id to ``(offset, size)`` in unquantized
            units. Must also contain a ``"total"`` entry whose first element
            is the unquantized total size.
        loaded_shard_id: Key into ``qkv_offsets`` for the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)``. Note the order: size first,
        offset second, i.e. the reverse of the input tuples.
    """
    unquantized_total, _ = qkv_offsets["total"]
    unquantized_offset, unquantized_size = qkv_offsets[loaded_shard_id]
    quantized_total = param.data.shape[0]

    def rescale(value: int) -> int:
        # Multiply before dividing so the floor is applied only once.
        return value * quantized_total // unquantized_total

    return rescale(unquantized_size), rescale(unquantized_offset)


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused scale array that a loaded scale belongs to.

    For fused modules (QKV and MLP) the param is an array of length N that
    holds one scale per "logical" matrix, while the loaded_weight is the
    scale of a single shard on disk.

    Args:
        param: Fused array; indexed with the resolved integer shard index.
        loaded_weight: The scale read from disk, either shapeless or with a
            leading dimension of exactly 1.
        shard_id: Integer index, or one of "q", "k", "v" (mapped to 0, 1, 2).

    Returns:
        ``(param[shard_index], loaded_weight)`` where a shaped loaded_weight
        has been replaced by its first element.

    Raises:
        ValueError: If shard_id is neither a str nor an int.
        KeyError: If shard_id is a str other than "q", "k" or "v".
    """
    if isinstance(shard_id, str):
        shard_id = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape, compressed-tensors scales do;
    # strip the leading dimension so both look the same to the caller.
    scale_shape = loaded_weight.shape
    if len(scale_shape) > 0:
        assert scale_shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


67
class LinearMethodBase(QuantizeMethodBase):
    """Interface for the (possibly quantized) methods used by linear layers.

    A linear method first attaches its weights to a layer through
    ``create_weights`` and later computes the layer output through ``apply``.
    """

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the weights of a linear layer as attributes of ``layer``.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., for QKVLinear this is a list holding
                the widths of Wq, Wk and Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Additional attributes for the created
                weights; how they are used is up to the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights held by ``layer`` to the input tensor ``x``.

        ``create_weights`` is expected to have been called on ``layer``
        beforehand.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method that keeps the weight unquantized."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register one uninitialized ``weight`` parameter on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)``, dtype
        ``params_dtype`` and ``requires_grad=False``. It is tagged with
        ``input_dim=1`` and ``output_dim=0`` and then with every entry of
        ``extra_weight_attrs``. ``input_size`` and ``output_size`` are not
        used here.
        """
        rows = sum(output_partition_sizes)
        storage = torch.empty(rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)

        # The two attribute sets are applied in separate calls, built-in
        # dims first, exactly as before.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)``."""
        return F.linear(x, layer.weight, bias)
123
124


125
126
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Records the layer dimensions and selects the quantization method. It
    creates no weights itself, and ``forward`` is left to subclasses.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: Stored as-is; the subclasses in this file return the
            bias instead of adding it when this is true.
        params_dtype: Data type for the parameters. Falls back to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None, an
            ``UnquantizedLinearMethod`` is used.
        prefix: Passed to ``quant_config.get_quant_method`` as ``prefix``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        quant_method: Optional[QuantizeMethodBase]
        if quant_config is not None:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            quant_method = UnquantizedLinearMethod()
        self.quant_method = quant_method

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Not implemented here; subclasses provide the computation."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full (unsharded) weight is created on this layer through the
    selected quantization method.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Nothing is sharded here, so the per-partition sizes equal the
        # full sizes.
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader,
                                         prefix=prefix)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint. A shapeless
                tensor is reshaped to shape ``(1,)`` first.

        Raises:
            AssertionError: If the sizes differ after the optional reshape.
        """
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(
            self,
            x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer to ``x``.

        Returns:
            ``(output, output_bias)``. With ``skip_bias_add`` the bias is not
            passed to the quant method and is returned as ``output_bias``
            instead; otherwise ``output_bias`` is None.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a short description of the layer for ``repr``."""
        return (f"in_features={self.input_size}"
                f", output_features={self.output_size}"
                f", bias={self.bias is not None}")

237

238
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. NOTE: this constructor does
                       not read the argument; the per-shard sizes come from a
                       ``self.output_sizes`` attribute when a subclass sets
                       one before calling this constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Shard the output dimension across the tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)

        if bias:
            # Use the dtype resolved by LinearBase (never None), matching
            # what is passed to create_weights above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        When ``param`` has an ``output_dim`` attribute, the slice of length
        ``param.data.shape[output_dim]`` starting at ``tp_rank`` times that
        length is taken along ``output_dim``; otherwise the whole tensor is
        used.

        Raises:
            AssertionError: If the resulting shape differs from the shape
                of ``param.data``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(
        self, input_: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[Parameter]]:
        """Apply the layer to ``input_``.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across the
            partitions when ``gather_output`` is set. With ``skip_bias_add``
            the bias is returned as ``output_bias`` instead of being passed
            to the quant method; otherwise ``output_bias`` is None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        """Return a short description of the layer for ``repr``."""
        return (f"in_features={self.input_size}"
                f", output_features={self.output_size_per_partition}"
                f", bias={self.bias is not None}"
                f", tp_size={get_tensor_model_parallel_world_size()}"
                f", gather_output={self.gather_output}")

351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Set before the parent constructor runs: it looks for an
        # `output_sizes` attribute to derive the per-partition sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(size % tp_size == 0 for size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard, or a fused tensor, into ``param``.

        Args:
            param: Destination parameter. The optional attributes
                ``output_dim``, ``packed_dim``, ``pack_factor``,
                ``marlin_tile_size``, ``use_bitsandbytes``, ``is_metadata``,
                ``needs_scalar_to_array`` and ``ignore_warning`` select the
                loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                matrix being loaded. None means ``loaded_weight`` is already
                fused on disk; it is then split and each piece is loaded
                through a recursive call.

        Raises:
            AssertionError: If the destination slice and the loaded tensor
                end up with different shapes.
        """
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            packed_dim = getattr(param, "packed_dim", None)
            running_offset = 0
            for shard_id, full_size in enumerate(self.output_sizes):
                shard_offset, shard_size = running_offset, full_size
                running_offset += full_size
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size //= param.pack_factor
                    shard_offset //= param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                self.weight_loader(
                    param,
                    loaded_weight.narrow(output_dim, shard_offset,
                                         shard_size), shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Position and length of this logical matrix inside the
            # per-rank fused parameter.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if getattr(param, "packed_dim", None) == output_dim:
                shard_size //= param.pack_factor
                shard_offset //= param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if getattr(param, "use_bitsandbytes", False):
                # The loaded tensor's own extent defines the shard here.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = shard_size * loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            param_data = param_data.narrow(0, loaded_shard_id * shard_size,
                                           shard_size)
        elif needs_scalar_to_array:
            # Special case for per-tensor scales in fused case.
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "MergedColumnParallelLinear, assume the weight is "
                "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                   is None else total_num_kv_heads)

        # Split the heads across the tensor-parallel ranks.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if self.total_num_kv_heads > tp_size:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # Not more KV heads than ranks: every rank holds one KV head
            # and each KV head is replicated on several ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)

        q_width = self.num_heads * self.head_size * tp_size
        kv_width = self.num_kv_heads * self.head_size * tp_size
        # Set before the parent constructor runs: it looks for an
        # `output_sizes` attribute to derive the per-partition sizes.
        self.output_sizes = [
            q_width,  # q_proj
            kv_width,  # k_proj
            kv_width,  # v_proj
        ]

        super().__init__(input_size=self.hidden_size,
                         output_size=q_width + 2 * kv_width,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load the q, k or v shard, or a fused QKV tensor, into ``param``.

        Args:
            param: Destination parameter. The optional attributes
                ``output_dim``, ``packed_dim``, ``pack_factor``,
                ``marlin_tile_size``, ``use_bitsandbytes``, ``is_metadata``,
                ``needs_scalar_to_array`` and ``ignore_warning`` select the
                loading path.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: One of "q", "k", "v". None means
                ``loaded_weight`` is already fused on disk; it is then split
                and each piece is loaded through a recursive call.

        Raises:
            AssertionError: If ``loaded_shard_id`` is not None and not one of
                "q", "k", "v", or if the destination slice and the loaded
                tensor end up with different shapes.
        """
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            # Extents in the fused checkpoint tensor use the TOTAL head
            # counts, not the per-rank ones.
            q_total = self.total_num_heads * self.head_size
            kv_total = self.total_num_kv_heads * self.head_size
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, q_total),
                ("k", q_total, kv_total),
                ("v", q_total + kv_total, kv_total),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size //= param.pack_factor
                    shard_offset //= param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                self.weight_loader(
                    param,
                    loaded_weight.narrow(output_dim, shard_offset,
                                         shard_size), shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank widths of the q block and of each k/v block.
            q_width = self.num_heads * self.head_size
            kv_width = self.num_kv_heads * self.head_size
            if loaded_shard_id == "q":
                shard_offset, shard_size = 0, q_width
            elif loaded_shard_id == "k":
                shard_offset, shard_size = q_width, kv_width
            elif loaded_shard_id == "v":
                shard_offset, shard_size = q_width + kv_width, kv_width

            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if getattr(param, "packed_dim", None) == output_dim:
                shard_size //= param.pack_factor
                shard_offset //= param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if getattr(param, "use_bitsandbytes", False):
                # (offset, size) per shard in unquantized units; "total"
                # carries the unquantized total in its first slot.
                orig_qkv_offsets = {
                    "q": (0, q_width),
                    "k": (q_width, kv_width),
                    "v": (q_width + kv_width, kv_width),
                    "total": (q_width + 2 * kv_width, 0),
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are indexed by rank directly; ranks that share a
            # replicated KV head read the same k/v slice.
            shard_id = (tp_rank if loaded_shard_id == "q" else tp_rank //
                        self.num_kv_head_replicas)
            loaded_weight = loaded_weight.narrow(output_dim,
                                                 shard_id * shard_size,
                                                 shard_size)
        elif is_metadata:
            # Special case for AQLM codebooks: metadata indicates fixed size
            # concatenated along dim 0.
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        elif needs_scalar_to_array:
            # Special case for per-tensor scales in fused case.
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        elif not getattr(param, "ignore_warning", False):
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


685
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and the tensor-parallel world size is
                        greater than 1), the per-rank partial outputs are
                        all-reduced in `forward`. If false, each rank returns
                        its own partial output.
        quant_config: Quantization configuration.
        prefix: Forwarded unchanged to the base class and to the quantization
                method's `create_weights` (presumably the layer's name inside
                the model -- confirm against `LinearBase`).

    Raises:
        ValueError: if `reduce_results` is false while a bias would be added
                    inside this layer (`bias` true and `skip_bias_add` false).
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Each rank owns 1/tp_size of the input dimension; the output
        # dimension is kept whole on every rank.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)

        # Without the all-reduce every rank returns only a partial sum, so a
        # bias added inside this layer would not yield a correct result.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            # The bias carries no `input_dim` attribute, so `weight_loader`
            # copies it whole instead of sharding it.
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's portion of `loaded_weight` into `param`.

        If `param` has an `input_dim` attribute, `loaded_weight` is narrowed
        along that dimension to the slice belonging to the current
        tensor-parallel rank; otherwise it is used as-is. The (possibly
        narrowed) tensor must have exactly the shape of `param.data`.

        Arguments:
            param: destination parameter; updated in place.
            loaded_weight: tensor read from the checkpoint.
        """
        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            # The local parameter already has the per-rank size, so its
            # extent along `input_dim` is the shard size.
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Apply the row-parallel linear transform.

        Arguments:
            input_: input tensor. When `input_is_parallel` is false it is
                    split along its last dimension into `tp_size` parts and
                    only this rank's part is used.

        Returns:
            A tuple `(output, output_bias)`. `output_bias` is `self.bias`
            when `skip_bias_add` is set and `None` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        """Return a summary of the layer's per-rank configuration."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s