linear.py 45.8 KB
Newer Older
1
from abc import abstractmethod
2
from typing import Dict, List, Optional, Tuple
3
4
5

import torch
import torch.nn.functional as F
6
from torch.nn.parameter import Parameter, UninitializedParameter
7

8
9
10
11
12
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
13
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
16
from vllm.model_executor.parameter import (BasevLLMParameter,
17
18
                                           PackedvLLMParameter,
                                           PerTensorScaleParameter)
19
20
21
22
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)

23
# Class names of quant methods for which ColumnParallelLinear (and its
# subclasses) register `weight_loader_v2` -- the BasevLLMParameter-based
# loader -- instead of the legacy `weight_loader`. Membership is checked
# against `quant_method.__class__.__name__`, so entries must match exactly.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod"
]
27

28

29
30
31
32
33
34
35
36
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    Params without a ``marlin_tile_size`` attribute (or with it set to
    None) leave both values untouched.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by
        ``param.marlin_tile_size`` when that attribute is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Rescale a QKV shard's (offset, size) onto a BitsAndBytes-packed param.

    ``qkv_offsets`` maps each shard id to ``(offset, size)`` in the
    unquantized layout; the first slot of its ``"total"`` entry is the
    unquantized total. Both values for ``loaded_shard_id`` are scaled by
    ``param.data.shape[0] / total`` using floor division.

    Returns:
        ``(quantized_size, quantized_offset)`` -- size comes first.
    """
    unquantized_total, _ = qkv_offsets["total"]
    offset, size = qkv_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    def rescale(value: int) -> int:
        return value * packed_rows // unquantized_total

    return rescale(size), rescale(offset)


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select one logical shard's slot in a fused per-tensor scale array.

    For fused modules (QKV and MLP) the param is an array of length N that
    holds one scale per "logical" matrix, while ``loaded_weight`` is the
    scale of a single on-disk shard. This returns the slot of ``param``
    addressed by ``shard_id`` together with the loaded scale, so the caller
    can ``copy_`` one into the other.

    Args:
        param: Fused scale tensor indexed along its first dim by shard.
        loaded_weight: Scale of one shard; either 0-d (AutoFP8) or with a
            leading dim of size 1 (compressed-tensors).
        shard_id: Integer shard index, or one of ``"q"``, ``"k"``, ``"v"``.

    Returns:
        ``(param[shard_id], loaded_weight)`` with the leading size-1 dim of
        ``loaded_weight`` dropped when present.

    Raises:
        ValueError: If ``shard_id`` is neither an int nor one of
            ``"q"``/``"k"``/``"v"``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Unknown names get the same ValueError as other unsupported ids
        # instead of leaking a KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


75
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra keyword arguments forwarded by the
                calling layer (the layers in this file pass
                ``weight_loader`` and ``prefix``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.

        Args:
            layer: Layer whose weights were created by ``create_weights``.
            x: Input activations.
            bias: Optional bias to add to the output; the layers in this
                file pass None when ``skip_bias_add`` is set.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a dense, non-trainable ``weight`` parameter on ``layer``.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` -- all logical output partitions stacked
        along dim 0 -- and is tagged with ``input_dim=1``, ``output_dim=0``
        plus every entry of ``extra_weight_attrs``. ``input_size`` and
        ``output_size`` are not used by this method.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``F.linear(x, layer.weight, bias)``."""
        return F.linear(x, layer.weight, bias)
131
132


133
134
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None the layer uses
            ``UnquantizedLinearMethod``; otherwise the method returned by
            ``quant_config.get_quant_method``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj); forwarded to
                        ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subclasses implement the forward pass."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full, unsharded weight (and bias); ``forward``
    performs no tensor-parallel communication.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # Every linear layer requires a quant method (possibly unquantized).
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader,
                                         prefix=prefix)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy an unsharded checkpoint tensor into ``param``.

        The sizes must match exactly; a 0-d tensor is first reshaped to
        ``(1,)``.
        """
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(output, output_bias)``.

        The bias is added inside the matmul unless ``skip_bias_add`` is set,
        in which case it is returned as ``output_bias`` instead (else None).
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

245

246
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Each entry is divided across
                       tensor-parallel ranks separately. A subclass that sets
                       ``self.output_sizes`` before calling this constructor
                       takes precedence over this argument.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # Those subclasses set `self.output_sizes` before calling this
        # constructor; otherwise an explicitly passed `output_sizes` argument
        # is honoured the same way.
        logical_output_sizes = getattr(self, "output_sizes", output_sizes)
        if logical_output_sizes is not None:
            self.output_partition_sizes = [
                divide(size, tp_size) for size in logical_output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
            prefix=prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Params with an ``output_dim`` attribute receive the slice
        ``[tp_rank * n, (tp_rank + 1) * n)`` along that dim, where ``n`` is
        the param's size there; other params receive the whole tensor.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter. The param holds only this
        # rank's slice, so divide the output dim by the TP size; with the full
        # on-disk shape the narrow() below would start past the end of
        # loaded_weight on every rank but 0.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                final_shape[output_dim] = divide(
                    final_shape[output_dim],
                    get_tensor_model_parallel_world_size())
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to the param's own column-parallel loader."""
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output`` is all-gathered across TP ranks when ``gather_output`` is
        set, otherwise it is this rank's partition. ``output_bias`` is the
        bias when ``skip_bias_add`` is set, else None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before the parent constructor, which reads it to size
        # the per-partition outputs.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard (or an already-fused tensor) into ``param``.

        Args:
            param: Destination parameter (this rank's part of the fused
                weight).
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                weight being loaded, or None when ``loaded_weight`` already
                holds all logical weights fused along the output dim; in that
                case it is split and this method recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            if is_gguf_weight:
                # Record this shard's per-rank shape on the GGUF param.
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, PackedvLLMParameter
                          ) and param.packed_dim == param.output_dim:
                # The adjusted bounds are returned (ints cannot be updated in
                # place), so they must be assigned; otherwise the narrow()
                # below slices with the unadjusted size and offset.
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a shard via the param's ``load_merged_column_weight``.

        With ``loaded_shard_id`` None, per-tensor scales are written to slot
        0, plain ``BasevLLMParameter`` instances load the tensor as-is, and
        anything else is treated as a fused-on-disk checkpoint and split by
        ``_load_fused_module_from_checkpoint``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) is BasevLLMParameter:
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
        shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: each rank holds a single KV
            # head, shared by `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the offset of shard "q"/"k"/"v" (or the "total" width)
        within this rank's fused QKV output partition; None if unknown."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the size of shard "q"/"k"/"v" within this rank's fused QKV
        output partition; None if unknown."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _fused_checkpoint_shard_offsets(self) -> List[Tuple[str, int, int]]:
        """Return (shard_id, offset, size) for q/k/v in an unsharded, fused
        QKV checkpoint tensor (sizes use the totals across all ranks)."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._fused_checkpoint_shard_offsets()):
            # Special case for Quantization.
            # If quantized, the checkpoint tensor is packed along the output
            # dim, so it must be narrowed with the packed offset/size (the v1
            # `weight_loader` below does the same by hand). The adjusted pair
            # is returned by the parameter and must be rebound here;
            # previously it was discarded and the unpacked offsets were used.
            if isinstance(param, PackedvLLMParameter
                          ) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor via the vLLM-parameter (v2) API.

        Args:
            param: Destination parameter; it performs the actual copy.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k", "v", or None when the checkpoint
                stores the already-fused QKV tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) is BasevLLMParameter:
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into ``param`` (v1 loader path).

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint: a single q/k/v
                projection when ``loaded_shard_id`` is given, otherwise the
                QKV tensor already fused on disk.
            loaded_shard_id: "q", "k", "v", or None for fused-on-disk weights.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type and loaded_shard_id is not None:
            idx_map = {"q": 0, "k": 1, "v": 2}
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in (
                    self._fused_checkpoint_shard_offsets()):
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                # Unpacked (offset, size) per shard, plus the total width.
                orig_qkv_offsets = {
                    sid: (self._get_shard_offset_mapping(sid),
                          self._get_shard_size_mapping(sid))
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Q is partitioned across all ranks; K/V heads may be replicated,
            # in which case several ranks read the same K/V slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and tp_size > 1), all-reduce the per-rank
                        partial outputs. If false, each rank returns its
                        partial result; a bias must then be disabled or
                        returned via ``skip_bias_add``, otherwise a
                        ValueError is raised.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all
                parents.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Shard the input dimension (the weight's last dimension) across
        # tensor-parallel ranks.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
            prefix=prefix)
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is not sharded: every rank holds the full vector.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into ``param``.

        Parameters carrying an ``input_dim`` attribute are narrowed to the
        ``tp_rank``-th chunk along that dimension; parameters without it
        (e.g. the bias, which only sets ``output_dim``) are copied whole.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` rather than truthiness: input_dim == 0 is a valid
            # dimension and must still be sharded, otherwise the narrow below
            # would read past the end of loaded_weight on ranks > 0.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load via the parameter's own row-parallel logic (v2 API)."""
        # Mirror `weight_loader`: scales stored on disk without a shape
        # (such as in the case of AutoFP8) are promoted to shape (1,).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)
        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        """Apply the sharded linear layer.

        Returns:
            ``(output, output_bias)`` where ``output_bias`` is ``self.bias``
            when ``skip_bias_add`` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s