linear.py 56.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from abc import abstractmethod
5
from typing import Optional, Union
6
7
8

import torch
import torch.nn.functional as F
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
19
# yapf: disable
20
from vllm.model_executor.parameter import (BasevLLMParameter,
21
                                           BlockQuantScaleParameter,
22
                                           PackedColumnParameter,
23
                                           PackedvLLMParameter,
24
25
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
26
# yapf: enable
27
28
29
30
from vllm.model_executor.utils import set_weight_attrs

# Module-level logger; used below to warn about weights loaded without an
# `output_dim` attribute in MergedColumnParallelLinear.
logger = init_logger(__name__)

31
# Class names of quantization linear methods whose parameters are loaded via
# `weight_loader_v2`. ColumnParallelLinear (and its subclasses) select
# `weight_loader_v2` when `quant_method.__class__.__name__` appears here and
# fall back to the legacy `weight_loader` otherwise.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod"
]
39

40

41
42
43
44
45
46
47
48
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    Params that have no ``marlin_tile_size`` attribute (or have it set to
    None) leave the shard geometry untouched.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


49
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    A shard's ``(offset, size)`` in the unquantized output dimension is
    rescaled proportionally onto ``param``'s first dimension
    (``param.data.shape[0]``), using integer floor division.

    Args:
        param: The (quantized) parameter being loaded into.
        shard_offsets: Maps a shard id to its unquantized ``(offset, size)``.
            The special ``"total"`` entry holds the total unquantized size as
            its first element (the second element is ignored).
        loaded_shard_id: Key into ``shard_offsets`` for the shard to adjust.

    Returns:
        ``(quantized_size, quantized_offset)``. Note the order is the reverse
        of the ``(offset, size)`` tuples in ``shard_offsets``.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused scale array that one checkpoint shard fills.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix, so
    ``param`` holds one entry per logical matrix, while ``loaded_weight`` is
    the scale for a single shard on disk.

    Args:
        param: Fused scale array, indexed by shard position.
        loaded_weight: Scale from the checkpoint. Either shapeless (AutoFP8)
            or shaped with a leading dimension of size 1 (compressed-tensors).
        shard_id: Integer position, or one of "q", "k", "v" (0, 1, 2).

    Returns:
        ``(param[shard_id], loaded_weight)`` with ``loaded_weight`` reduced
        to its single leading element when it had a shape.

    Raises:
        ValueError: If ``shard_id`` is neither a str nor an int.
        KeyError: If ``shard_id`` is a str other than "q", "k" or "v".
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # Shapeless scales (AutoFP8) are used directly; shaped scales
    # (compressed-tensors) must carry exactly one leading element.
    if loaded_weight.shape:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


87
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                parameters (the layers in this file pass ``weight_loader``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialized, unquantized ``weight`` on ``layer``.

        The weight has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and dtype ``params_dtype``. Its
        ``input_dim``/``output_dim`` attributes record which axis is which,
        and ``extra_weight_attrs`` (e.g. ``weight_loader``) are attached too.
        ``input_size`` and ``output_size`` are not used by this method.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)`` (bias may be None)."""
        return F.linear(x, layer.weight, bias)
143
144


145
146
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None,
            ``torch.get_default_dtype()`` is used.
        quant_config: Quantization configuration. When None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: The name of the layer in the state dict, including all
            parents (e.g. model.layers.0.qkv_proj). Forwarded to
            ``quant_config.get_quant_method``.
        return_bias: If true, return bias together with outputs in forward
            pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        # Typed Optional because get_quant_method may not provide a method;
        # subclasses assert it is set before creating weights.
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the layer. Must be implemented by subclasses."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    The full weight (and bias) is allocated on every rank; no
    tensor-parallel sharding is applied.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward
                     pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

        # quant_method is typed Optional; every linear layer needs one here.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy a full (unsharded) checkpoint tensor into ``param``.

        For GGUF, a weight-type param stores ``loaded_weight.item()`` in
        ``param.weight_type``, and an ``UninitializedParameter`` is
        materialized with the checkpoint's shape and dtype before the copy.
        """
        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer; optionally return the un-added bias alongside."""
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

284

285
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Not read by this class
                       itself: the per-partition layout is taken from
                       ``self.output_sizes`` when a subclass sets it before
                       calling this constructor.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward
                     pass.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # (Those subclasses set self.output_sizes before calling us.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, self.tp_size) for size in self.output_sizes
            ]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.gather_output = gather_output

        assert self.quant_method is not None
        # Quant methods listed in WEIGHT_LOADER_V2_SUPPORTED use the v2
        # loader; all others fall back to the legacy weight_loader.
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # Use the resolved dtype (LinearBase replaces None with the
            # default dtype), matching the weights created above.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a checkpoint tensor into ``param``.

        When ``param`` has an ``output_dim`` attribute, ``loaded_weight`` is
        narrowed to the ``tp_rank``-th chunk (of ``param``'s size) along that
        dim, unless the checkpoint tensor is already sharded
        (``is_sharded_weight``) or bitsandbytes 4-bit is in use.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter with this rank's shape.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if output_dim is not None:
                tp_size = get_tensor_model_parallel_world_size()
                assert final_shape[output_dim] % tp_size == 0
                final_shape[output_dim] = final_shape[output_dim] // tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate loading to ``param.load_column_parallel_weight``."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply this rank's partition, all-gathering if ``gather_output``."""
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
        return_bias: If true, return bias together with outputs in forward
                     pass.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Must be set before super().__init__(): ColumnParallelLinear reads
        # self.output_sizes to compute per-partition output sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor into the merged (fused) ``param``.

        Args:
            param: The fused parameter being loaded.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                sub-matrix ``loaded_weight`` belongs to. ``None`` means the
                checkpoint already stores the fused tensor: it is copied
                whole when ``param`` has no ``output_dim``, otherwise split
                per shard and loaded recursively with each shard id.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    i: loaded_weight.item()
                    for i, _ in enumerate(self.output_sizes)
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                if len(param.data_container) == 2:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)

            # The unquantized (offset, size) table used for bitsandbytes
            # does not depend on the shard, so build it once.
            if use_bitsandbytes_4bit:
                index = list(itertools.accumulate(self.output_sizes,
                                                  initial=0))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)

            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load ``loaded_weight`` via ``param.load_merged_column_weight``.

        With ``loaded_shard_id`` set, the shard's offset and size within this
        rank's partition are computed (in blocks of ``block_n`` for
        ``BlockQuantScaleParameter``) and passed along. Without it, the
        checkpoint tensor is treated as already fused.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact-type check: subclasses of these types fall through to
            # the fused-module split below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes are counted in (ceil-divided) blocks of block_n.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
        return_bias: Passed through to ColumnParallelLinear; presumably
                     controls whether forward() also returns the bias, as the
                     same flag does in RowParallelLinear.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # At least as many ranks as KV heads: every rank holds exactly one
            # KV head, and each KV head is replicated on
            # tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Sized from the per-rank head counts times tp_size, so replicated
        # KV heads are counted once per rank.
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix,
                         return_bias=return_bias)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the per-rank output offset of shard ``loaded_shard_id``.

        Accepts ``"q"``, ``"k"``, ``"v"`` or ``"total"`` (the full per-rank
        fused width); any other key yields ``None``.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank output width of shard ``"q"``/``"k"``/``"v"``.

        Any other key yields ``None``.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Offsets/sizes here are in the *unsharded* on-disk layout, hence the
        # use of the total (not per-rank) head counts.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into a ``BasevLLMParameter``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` when the checkpoint
                stores the projections separately; ``None`` when it stores
                the already-fused QKV tensor.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            # Exact type match on purpose: subclasses fall through to the
            # generic fused-split path below.
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        if isinstance(param, BlockQuantScaleParameter):
            # Block-quantized scales are indexed in units of `block_n` output
            # rows (MergedColumnParallelLinear.weight_loader_v2 ceil-divides
            # by block_n for the same parameter type), so convert the
            # row-based offset/size computed above into block units.
            assert self.quant_method is not None
            quant_config = getattr(self.quant_method, "quant_config", None)
            weight_block_size = getattr(quant_config, "weight_block_size",
                                        None)
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Legacy loader for plain ``Parameter`` objects.

        Copies this rank's slice of ``loaded_weight`` into ``param``. Handles
        GGUF weights, checkpoints with the QKV tensor already fused
        (``loaded_shard_id is None``; split and re-dispatched per shard),
        packed/Marlin/bitsandbytes-4bit layouts, AQLM codebook metadata and
        per-tensor scales.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once all three of q/k/v have been collected.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Re-enter with an explicit shard id for each of q/k/v.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Per-rank placement of this shard inside the fused parameter.
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # KV heads may be replicated across ranks, so several ranks map
            # to the same slice of the checkpoint's k/v tensors.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, forward() all-reduces the partial outputs
                        across the tensor-parallel group (when tp_size > 1).
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all
                parents.
        return_bias: If true, forward() returns ``(output, output_bias)``;
                     otherwise it returns only ``output``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            # The bias is full-size on every rank (not sharded).
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        Tensors that carry an ``input_dim`` attribute are narrowed to this
        rank's partition along that dimension, unless they are already
        sharded on disk (``is_sharded_weight`` or bitsandbytes 4-bit).
        GGUF ``UninitializedParameter`` objects are materialized first.
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `is not None` (not truthiness): dim 0 is a valid input_dim and
            # must still be divided across ranks, matching the narrow below.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Load ``loaded_weight`` via the parameter's own row-parallel
        loading logic; 0-dim scalars are first reshaped to shape ``(1,)``."""

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(
        self, input_
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the row-parallel projection.

        Splits ``input_`` along its last dimension unless it is already
        parallel, multiplies by the local weight shard, and all-reduces the
        partial results when ``reduce_results`` and ``tp_size > 1``.

        Returns:
            ``output`` if ``return_bias`` is false, otherwise
            ``(output, output_bias)`` where ``output_bias`` is the bias when
            ``skip_bias_add`` is set and ``None`` otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s


class QKVCrossParallelLinear(torch.nn.Module):
    """QKV projection for cross-attention that loads as a single module.

    Queries are projected from the decoder hidden states by a
    ``ColumnParallelLinear``. Keys/values are projected from the encoder
    hidden states by a ``QKVParallelLinear`` built with zero query heads.
    The placeholder ``weight`` (and, with ``bias=True``, ``bias``)
    parameters exist only so checkpoint shards tagged ``"q"``/``"k"``/``"v"``
    can be routed to the right sub-layer. This lets the module be used as a
    drop-in replacement for a ``QKVParallelLinear`` during weight loading.

    Args:
        hidden_size: input hidden state size (decoder and encoder).
        head_size: size of each attention head.
        total_num_heads: total number of query heads.
        total_num_kv_heads: total number of key/value heads.
        bias: If true, the sub-layers add a bias.
        skip_bias_add: Forwarded to both sub-layers.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all
                parents.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Empty placeholders for loading as a single module.
        self.weight = torch.nn.Parameter()
        set_weight_attrs(self.weight, {
            "weight_loader": self.weight_loader_weight,
        })
        # Use a dictionary to avoid submodules parameters auto-registration:
        # drop-in replacement for a `QKVParallelLinear` module.
        self.proj = {}
        self.proj["q_proj_decoder"] = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=total_num_heads * head_size,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.q_proj_decoder")

        self.proj["kv_proj_encoder"] = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_size,
            total_num_heads=0,
            total_num_kv_heads=total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            prefix=f"{prefix}.kv_proj_encoder")

        # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1.
        self.kv_size = self.kv_proj_encoder.num_kv_heads * head_size

        if bias:
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "weight_loader": self.weight_loader_bias,
            })

    @property
    def q_proj_decoder(self):
        return self.proj["q_proj_decoder"]

    @property
    def kv_proj_encoder(self):
        return self.proj["kv_proj_encoder"]

    def forward(self, decoder_hidden_states, encoder_hidden_states):
        """Return ``(q, k, v)``.

        ``k`` and ``v`` are ``None`` when ``encoder_hidden_states`` is
        ``None`` (encoder KV already cached).
        """
        q, _ = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            k = None
            v = None
        else:
            # Prefill phase, encoder KV cached here.
            kv_enc, _ = self.kv_proj_encoder(encoder_hidden_states)
            # Split kv in half
            k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v

    def _load_into_sub_layer(self, attr_name: str,
                             loaded_weight: torch.Tensor,
                             loaded_shard_id: Optional[str]):
        """Route a checkpoint tensor to ``attr_name`` of the owning sub-layer.

        ``"q"`` shards go to the decoder query projection, whose loader takes
        no shard id. Every other shard id (including ``None``) goes to the
        encoder KV projection together with its shard id.
        """
        if loaded_shard_id == "q":
            target = getattr(self.q_proj_decoder, attr_name)
            target.weight_loader(target, loaded_weight)
        else:
            target = getattr(self.kv_proj_encoder, attr_name)
            target.weight_loader(target, loaded_weight, loaded_shard_id)

    def weight_loader_weight(self,
                             param: torch.nn.Parameter,
                             loaded_weight: torch.Tensor,
                             loaded_shard_id: Optional[str] = None):
        """Load a q/k/v weight shard into the matching sub-layer's weight."""
        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
        self._load_into_sub_layer("weight", loaded_weight, loaded_shard_id)

    def weight_loader_bias(self,
                           param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor,
                           loaded_shard_id: Optional[str] = None):
        """Load a q/k/v bias shard into the matching sub-layer's bias."""
        # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
        self._load_into_sub_layer("bias", loaded_weight, loaded_shard_id)