linear.py 51.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
from abc import abstractmethod
5
from typing import Optional
6
7
8

import torch
import torch.nn.functional as F
9
from torch.nn.parameter import Parameter, UninitializedParameter
10

11
12
13
14
15
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
19
# yapf: disable
20
from vllm.model_executor.parameter import (BasevLLMParameter,
21
                                           BlockQuantScaleParameter,
22
                                           PackedColumnParameter,
23
                                           PackedvLLMParameter,
24
25
                                           PerTensorScaleParameter,
                                           RowvLLMParameter)
26
# yapf: enable
27
from vllm.model_executor.utils import set_weight_attrs
gaoqiong's avatar
gaoqiong committed
28

zhuwenwen's avatar
zhuwenwen committed
29
import os
30
from vllm.model_executor.utils import gemm_bank_conf
31
32
33

# Module-level logger; MergedColumnParallelLinear.weight_loader uses it to
# warn about weights that lack an ``output_dim`` attribute.
logger = init_logger(__name__)

34
# Names of quant-method classes whose layers are wired to ``weight_loader_v2``
# (see ColumnParallelLinear.__init__, which compares against
# ``quant_method.__class__.__name__``). All other methods use the legacy
# ``weight_loader``.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
    "HQQMarlinMethod", "QuarkLinearMethod"
]
42

43

44
45
46
47
48
49
50
51
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Scale a shard's size and offset by the param's Marlin tile size.

    Params without a ``marlin_tile_size`` attribute (or with it set to
    ``None``) leave the shard geometry untouched.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by the tile size
        when one is present.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        shard_size, shard_offset = shard_size * tile, shard_offset * tile
    return shard_size, shard_offset


52
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   shard_offsets: dict[str, tuple[int, int]],
                                   loaded_shard_id: str) -> tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``shard_offsets`` describes the unquantized weight. The packed param can
    have a different length along dim 0. So each offset and size is
    rescaled by ``param.data.shape[0] / total``.

    Args:
        param: the parameter being loaded into; only ``param.data.shape[0]``
            is read.
        shard_offsets: maps each shard id to ``(offset, size)`` in the
            unquantized weight. The ``"total"`` entry holds
            ``(total_size, _)``.
        loaded_shard_id: key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` along dim 0 of ``param``.
        Note the order is the reverse of the ``(offset, size)`` input pairs.
    """
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    # Multiply before dividing so integer division truncates only once.
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Pick the slot of a fused per-tensor array that a shard loads into.

    Fused modules (QKV and MLP) keep one scale per "logical" matrix in a
    length-N array ``param``. ``loaded_weight`` is the scale of a single
    on-disk shard.

    Args:
        param: array holding one entry per logical matrix.
        loaded_weight: the shard's scale. It is either 0-d (AutoFP8) or has
            a length-1 leading axis (compressed-tensors).
        shard_id: ``"q"``, ``"k"``, ``"v"`` or an integer index.

    Returns:
        ``(param[index], loaded_weight)``, with the leading length-1 axis of
        ``loaded_weight`` dropped when present.

    Raises:
        KeyError: a string ``shard_id`` other than ``"q"``/``"k"``/``"v"``.
        ValueError: a ``shard_id`` that is neither ``str`` nor ``int``.
    """
    if isinstance(shard_id, str):
        index = {"q": 0, "k": 1, "v": 2}[shard_id]
    elif isinstance(shard_id, int):
        index = shard_id
    else:
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales are shapeless; compressed-tensors scales carry a
    # single-element leading axis that has to be stripped.
    if len(loaded_weight.shape) > 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[index], loaded_weight


90
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
           The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the width of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes to attach to the created
                weights. The layers in this file pass ``weight_loader``.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Two environment switches are read once, at construction time:

    * ``LLAMA_NN=1``: ``apply`` computes ``x @ layer.weight`` with
      ``torch.matmul``/``torch.addmm`` instead of ``F.linear``. That only
      matches ``create_weights`` (which allocates ``(out, in)``) if the
      weight is transposed to ``(in, out)`` somewhere else. TODO confirm
      where that happens.
    * ``GEMM_PAD=1``: together with ``LLAMA_NN``, lets ``apply`` strip 32
      trailing columns from the weight when
      ``gemm_bank_conf(width - 32)`` is true.
    """

    def __init__(self):
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register a frozen ``(sum(output_partition_sizes),
        input_size_per_partition)`` ``weight`` on ``layer``.

        The weight is tagged with ``input_dim=1``/``output_dim=0`` and any
        ``extra_weight_attrs`` (e.g. ``weight_loader``).
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the (optionally biased) linear projection of ``x``."""
        if not self.use_llama_nn:
            return F.linear(x, layer.weight, bias)

        # Use the flag cached in __init__. The previous direct lookup
        # ``os.environ['GEMM_PAD']`` raised KeyError whenever GEMM_PAD was
        # unset and gemm_bank_conf() returned True.
        if self.use_gemm_pad and gemm_bank_conf(layer.weight.shape[1] - 32):
            # NOTE(review): this trims 32 columns on *every* call whose width
            # passes the check, and assigns a plain tensor slice to
            # ``layer.weight``. Confirm it is meant to fire only once and that
            # ``weight`` is not a registered Parameter at this point.
            layer.weight = layer.weight[:, :-32]

        if bias is None:
            return torch.matmul(x, layer.weight)
        if len(x.shape) == 2:
            # Fused bias-add + GEMM for the 2-D case.
            return torch.addmm(bias, x, layer.weight)
        return torch.matmul(x, layer.weight) + bias
162

163

164
165
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. Falls back to
                      ``torch.get_default_dtype()`` when None.
        quant_config: Quantization configure. When None the layer uses
                      ``UnquantizedLinearMethod``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            # Presumably get_quant_method may return None. The Optional
            # annotation and the subclasses' ``assert self.quant_method is
            # not None`` suggest so.
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, output_bias)``; implemented by subclasses."""
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full ``(output_size, input_size)`` weight. No
    tensor-parallel sharding or communication is performed.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        # Not partitioned: per-partition sizes equal the full sizes.
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` in place; shapes must match."""
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, output_bias)``.

        ``output_bias`` is the bias only when ``skip_bias_add`` is set, in
        which case it was not added to ``output``.
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s

277

278
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3. Accepted for backward
                       compatibility only; this class does not read it.
                       Subclasses that pack several outputs set
                       ``self.output_sizes`` before calling ``__init__``.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[list[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        # (Those subclasses assign self.output_sizes before super().__init__.)
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            # self.params_dtype already resolves a None params_dtype to the
            # default dtype (same as ReplicatedLinear).
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a full checkpoint tensor into ``param``.

        Along ``param.output_dim`` (when set), the rank's contiguous chunk of
        length ``param.shape[output_dim]`` is narrowed out. This is skipped
        when the checkpoint tensor is already sharded.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        param_data = param.data
        if output_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        """Delegate sharded loading to the vLLM parameter class itself."""
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Return ``(output, output_bias)``.

        ``output`` is all-gathered across ranks when ``gather_output`` is
        set. ``output_bias`` is non-None only when ``skip_bias_add`` is set.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s

418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: list[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before super().__init__, which reads it to compute the
        # per-partition output sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical shard (or an already-fused tensor) into ``param``.

        With ``loaded_shard_id=None`` the checkpoint tensor is assumed to be
        fused on disk. It is split per ``self.output_sizes`` and this method
        recurses once per shard.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            if loaded_shard_id is not None:
                param.data[loaded_shard_id].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                weight_type = loaded_weight.item()
                param.shard_weight_type = {
                    i: weight_type
                    for i in range(len(self.output_sizes))
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once every logical shard has been collected.
                # This was previously hard-coded to 2, which only fits
                # two-way merges such as gate/up.
                if len(param.data_container) == len(self.output_sizes):
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (mlp).
            # (e.g., Phi-3's gate_up_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            if use_bitsandbytes_4bit:
                # Unquantized (offset, size) of every logical shard. This is
                # loop-invariant, so build it once rather than per shard.
                index = list(itertools.accumulate([0] + self.output_sizes))
                orig_offsets = {
                    str(i): (index[i], size)
                    for i, size in enumerate(self.output_sizes)
                }
                orig_offsets["total"] = (self.output_size, 0)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_offsets, str(shard_id))

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        current_shard_offset = 0
        shard_offsets: list[tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a shard via the vLLM parameter's ``load_merged_column_weight``.

        Computes this rank's ``shard_offset``/``shard_size`` within the
        merged output dim. For block-quantized scales these are counted in
        ``block_n``-sized blocks rather than in rows.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets/sizes in units of ceil-divided output blocks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)

687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: every rank keeps a single KV head and
            # each KV head is replicated across `num_kv_head_replicas` ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Total output width across all ranks, counting replicated KV heads.
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        """Return the offset of a shard inside this rank's fused QKV output.

        Offsets are in elements along the output dim and use the per-rank
        head counts. "total" yields the full per-rank width. Unknown ids
        return None.
        """
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        """Return the per-rank size (in elements) of the q/k/v shard.

        Unknown ids return None.
        """
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        # Offsets here are into the full (unsharded) checkpoint tensor, so
        # they use the total head counts rather than the per-rank ones.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BasevLLMParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        """Load a checkpoint tensor into a vLLM parameter of this layer.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v" when the checkpoint stores the
                projections separately. None when they are already fused on
                disk.
        """
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Block-quantized scales are indexed in units of `block_n` output
        # rows, not single elements. Convert the element-based offset/size
        # (ceil-divided), as is already done for merged column shards.
        if isinstance(param, BlockQuantScaleParameter):
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8LinearMethod, Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Legacy loader for plain `Parameter`s of this layer.

        Args:
            param: Destination parameter.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: "q", "k" or "v" for a separate projection, or
                None when the checkpoint stores QKV fused. In the fused case
                the tensor is split and this method recurses once per shard.

        Special cases handled: GGUF weight types and weights, AQLM metadata
        (`is_metadata`), per-tensor scales (`needs_scalar_to_array`),
        bitsandbytes 4-bit, packed/Marlin quantized weights.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {
                    k: loaded_weight.item()
                    for k in idx_map
                }
            return

        if is_gguf_weight:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()

            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // tp_size
            start_idx = tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
                # Materialize once all three of q, k and v have arrived.
                if len(param.data_container) == 3:
                    self.qweight = param.materialize_nested()
                return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (self.total_num_heads * self.head_size,
                              self.total_num_kv_heads * self.head_size),
                        "v":
                        ((self.total_num_heads + self.total_num_kv_heads) *
                         self.head_size,
                         self.total_num_kv_heads * self.head_size),
                        "total":
                        ((self.total_num_heads + 2 * self.total_num_kv_heads) *
                         self.head_size, 0)
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Query heads are partitioned one slice per rank. KV heads may be
            # replicated, so ranks sharing a KV head read the same slice.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size

            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true (and tp_size > 1), all-reduce the per-rank
                        partial outputs so every rank holds the full Y;
                        otherwise each rank returns its partial X_i A_i.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.down_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        # Without the all-reduce each rank would add the full bias to its
        # partial sum, so that combination is rejected.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of a checkpoint tensor into `param`.

        Tensors with an `input_dim` attribute are narrowed to this rank's
        slice along that dim, unless the weight is already sharded
        (`is_sharded_weight` or bitsandbytes 4-bit). Zero-dim tensors such as
        scalar scales are reshaped to shape (1,).
        """
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        is_sharded_weight = getattr(param, "is_sharded_weight", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            # `input_dim` may legitimately be 0, so compare against None
            # instead of relying on truthiness.
            if input_dim is not None:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None and not is_sharded_weight:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BasevLLMParameter,
                         loaded_weight: torch.Tensor):
        """Delegate loading to the parameter's row-parallel loader.

        A zero-dim (scalar) tensor is first reshaped to shape (1,).
        """
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
        """Run the row-parallel matmul.

        Returns:
            (output, output_bias). `output_bias` is `self.bias` when
            `skip_bias_add` is set and None otherwise.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s