# linear.py
1
import itertools
2
from abc import abstractmethod
3
from typing import Dict, List, Optional, Tuple
4
5
6

import torch
import torch.nn.functional as F
7
from torch.nn.parameter import Parameter, UninitializedParameter
8

9
10
11
12
13
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
14
from vllm.logger import init_logger
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
17
# yapf: disable
18
from vllm.model_executor.parameter import (BasevLLMParameter,
19
                                           BlockQuantScaleParameter,
20
                                           PackedColumnParameter,
21
                                           PackedvLLMParameter,
22
23
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
24
# yapf: enable
25
26
27
28
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger created via vLLM's `init_logger`, keyed by this module.
logger = init_logger(__name__)

# Class names of quantization linear methods whose layers should be loaded
# with ``weight_loader_v2`` instead of the legacy ``weight_loader``. Layers
# compare ``self.quant_method.__class__.__name__`` against this list.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod"
]

def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    A ``param`` without a ``marlin_tile_size`` attribute (or with that
    attribute set to ``None``) leaves both values unchanged.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by the tile size
        when one is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: Dict[str, Tuple[int, int]],
                                   loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    Rescales a logical shard's ``(offset, size)`` from the unquantized
    layout to the quantized ``param``. The scale factor is
    ``param.data.shape[0] / total``.

    Args:
        param: Parameter being loaded into; only ``param.data.shape[0]`` is
            read.
        shard_offsets: Maps each shard id to ``(offset, size)`` in the
            unquantized layout. The special key ``"total"`` maps to
            ``(total_size, _)``; its second element is ignored.
        loaded_shard_id: Key in ``shard_offsets`` of the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)``. Note that the size comes
        first.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before floor-dividing so integer scaling loses no precision.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: The fused per-shard scale array.
        loaded_weight: The scale loaded from disk. It is either 0-dim or
            has a leading dim of size 1.
        shard_id: An int index, or one of ``"q"``, ``"k"``, ``"v"``
            (mapped to 0, 1, 2).

    Returns:
        ``(param[shard_id], loaded_weight)``. If ``loaded_weight`` had a
        leading dim, that dim is stripped.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or is neither a
            str nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names with the same error type as bad id types,
        # rather than leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                weights (the layers in this module pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a dense, uninitialized ``weight`` on ``layer``.

        The weight has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and
        ``requires_grad=False``. ``input_size`` and ``output_size`` (the
        unpartitioned sizes) are not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Record which tensor dim is input vs. output; weight loaders read
        # ``output_dim`` to decide which dim to shard along.
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)``."""
        return F.linear(x, layer.weight, bias)


class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None, uses
            ``torch.get_default_dtype()``.
        quant_config: Quantization configure. When None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj). Forwarded to
            ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The weight is created at its full (unpartitioned) input/output size, and
    the weight loader copies checkpoint tensors without slicing.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # Every linear layer is expected to have a quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` as-is (no sharding).

        A 0-dim ``loaded_weight`` is first reshaped to shape ``(1,)``.
        """
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(output, output_bias)``.

        When ``skip_bias_add`` is set, the bias is not applied; it is
        returned as ``output_bias`` instead. Otherwise ``output_bias`` is
        None.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this initializer;
                       packed subclasses set ``self.output_sizes`` before
                       calling it instead.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # Use the resolved dtype (never None), matching the weights.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``output_dim`` attribute, the checkpoint tensor
        is narrowed to rows ``[tp_rank * n, (tp_rank + 1) * n)`` along that
        dim. Here ``n`` is ``param``'s size along that dim. Bitsandbytes
        4-bit params skip this narrowing.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter at this rank's partition
        # size. Materializing at the full checkpoint size would make the
        # narrow below start at tp_rank * full_size, which is out of range
        # whenever tp_size > 1.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                tp_size = get_tensor_model_parallel_world_size()
                final_shape[output_dim] = divide(final_shape[output_dim],
                                                 tp_size)
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        param_data = param.data
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if output_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output`` is all-gathered across tensor-parallel ranks when
        ``gather_output`` is set. ``output_bias`` is the unapplied bias when
        ``skip_bias_add`` is set, else None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before the parent initializer, which reads it to size
        # each logical output partition.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard, or an already-fused tensor, into ``param``.

        With ``loaded_shard_id=None``, ``loaded_weight`` holds every logical
        shard fused along ``output_dim``. It is split and this method
        recurses once per shard. Otherwise ``loaded_weight`` is the full
        checkpoint tensor of shard ``loaded_shard_id``. This rank's slice of
        it is copied into the matching region of ``param``.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize only once every logical shard has arrived.
                # A hard-coded count of 2 would materialize too early for
                # more than two output_sizes and never for fewer.
                if len(param.data_container) == len(self.output_sizes):
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            # Start offset of each logical shard along output_dim.
            starts = list(itertools.accumulate([0] + self.output_sizes))
            shard_offsets: List[Tuple[int, int, int]] = [
                (i, starts[i], size) for i, size in enumerate(self.output_sizes)
            ]
            if use_bitsandbytes_4bit:
                # Unquantized layout per shard; identical for every shard in
                # the loop below, so build it once.
                orig_offsets: Dict[str, Tuple[int, int]] = {
                    str(i): (offset, size)
                    for i, offset, size in shard_offsets
                }
                orig_offsets["total"] = (self.output_size, 0)
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes_4bit:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Start offset of each logical shard along the output dim.
        starts = list(itertools.accumulate([0] + self.output_sizes))
        shard_offsets: List[Tuple[int, int, int]] = [
            (i, starts[i], size) for i, size in enumerate(self.output_sizes)
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a shard through ``param.load_merged_column_weight``.

        With ``loaded_shard_id=None``, two cases are handled:
        ``PerTensorScaleParameter`` loads as shard 0, and plain
        ``RowvLLMParameter``/``BasevLLMParameter`` (exact types) load without
        a shard id. Any other param is split via
        ``_load_fused_module_from_checkpoint``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets and sizes are in units of (ceil-divided) scale blocks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds a single KV head and
            # each KV head is replicated on tp_size // total_num_kv_heads
            # ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str) -> Optional[int]:
        """Return the per-rank output offset of shard "q"/"k"/"v", or the
        total per-rank width for "total"; None for any other id."""
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str) -> Optional[int]:
        """Return the per-rank output width of shard "q"/"k"/"v"; None for
        any other id."""
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _get_fused_shard_offsets(self) -> List[Tuple[str, int, int]]:
        """Return (shard_id, shard_offset, shard_size) for q, k and v inside
        an unsharded checkpoint tensor where the three projections are
        already concatenated along the output dimension."""
        q_size = self.total_num_heads * self.head_size
        kv_size = self.total_num_kv_heads * self.head_size
        return [
            ("q", 0, q_size),
            ("k", q_size, kv_size),
            ("v", q_size + kv_size, kv_size),
        ]

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        for shard_id, shard_offset, shard_size in (
                self._get_fused_shard_offsets()):
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a q/k/v shard (or a fused qkv tensor when loaded_shard_id is
        None) into a vLLM parameter via its load_qkv_weight method."""
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            # Block-quantized scales hold one row per block_n output rows,
            # so the weight-space offsets above must be converted to block
            # units (mirrors MergedColumnParallelLinear.weight_loader_v2).
            assert self.quant_method is not None
            quant_config = getattr(self.quant_method, "quant_config", None)
            weight_block_size = getattr(quant_config, "weight_block_size",
                                        None)
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Legacy loader: copy a q/k/v shard (or a fused qkv tensor when
        loaded_shard_id is None) into this rank's slice of ``param``.

        Handles GGUF weights/types, packed (quantized) and Marlin layouts,
        bitsandbytes 4-bit, AQLM metadata and per-tensor fused scales.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once all of q, k and v have been collected.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            packed_dim = getattr(param, "packed_dim", None)
            q_size = self.total_num_heads * self.head_size
            kv_size = self.total_num_kv_heads * self.head_size
            for shard_id, shard_offset, shard_size in (
                    self._get_fused_shard_offsets()):
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, q_size),
                        "k": (q_size, kv_size),
                        "v": (q_size + kv_size, kv_size),
                        "total": (q_size + 2 * kv_size, 0),
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id for each of q, k, v.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
            shard_size = self._get_shard_size_mapping(loaded_shard_id)

            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Per-rank (offset, size) of each shard plus the total width.
                orig_qkv_offsets = {
                    sid: (self._get_shard_offset_mapping(sid),
                          self._get_shard_size_mapping(sid))
                    for sid in ("q", "k", "v")
                }
                orig_qkv_offsets["total"] = (
                    self._get_shard_offset_mapping("total"), 0)
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are partitioned one slice per rank; KV heads may be
            # replicated, so num_kv_head_replicas consecutive ranks share one.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes_4bit:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


996
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized; in
              forward() it is fused into the GEMM on TP rank 0 only.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the per-rank partial outputs in
                        forward() when the TP size is > 1. Setting it to
                        False while bias is added inside the layer raises
                        ValueError.
        quant_config: Quantization config.
        prefix: Name of the layer including all parents; forwarded to
                LinearBase.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        # Without the all-reduce each rank returns only a partial sum, so a
        # bias added inside the layer would not be applied exactly once.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's input-dim slice of a checkpoint tensor into
        ``param`` (legacy loader, also used for the bias)."""
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # Compare against None: input_dim == 0 is a valid sharded dim
            # and must still be divided across ranks.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if input_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load a checkpoint tensor into a vLLM parameter via its
        load_row_parallel_weight method."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        """Return ``(output, output_bias)``; output_bias is the bias only
        when skip_bias_add is set, otherwise None."""
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s