parameter.py 23.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Hashable
5
from fractions import Fraction
6
from typing import Callable, Optional, Union
7
from weakref import WeakValueDictionary
8
9
10
11

import torch
from torch.nn import Parameter

12
13
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
14
from vllm.logger import init_logger
15
from vllm.utils import is_torch_equal_or_newer
16
17
18
19

# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
    "ModelWeightParameter", "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter",
    "BlockQuantScaleParameter", "SharedWeightParameter"
]

logger = init_logger(__name__)


class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: Optional[torch.Tensor], **kwargs):
        # Extra keyword arguments are ignored here; they are handled by
        # __init__ of this class and its subclasses. vLLM parameters never
        # require gradients.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data (consumed by
            ``__new__``; not used here)
        :param weight_loader: weight loader callable; it may be wrapped by
            the current platform so the parameter is synced after loading
        """
        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from vllm.platforms import current_platform
        if current_platform.use_sync_weight_loader():
            weight_loader = current_platform.make_synced_weight_loader(
                weight_loader)

        self._weight_loader = weight_loader
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    @property
    def weight_loader(self) -> Callable:
        # NOTE(@ksayers) some models such as mamba_mixer2 override the
        # weight loader to support custom loading. In the future, model-specific
        # weight loading should be implemented via Model.load_weights. In the
        # meantime, support deleting and overriding the `weight_loader`
        # attribute
        if self._weight_loader is None:
            raise AttributeError(f"{self.__class__.__name__} weight_loader "
                                 "attribute has been deleted")
        return self._weight_loader

    @weight_loader.setter
    def weight_loader(self, value: Callable):
        self._weight_loader = value

    @weight_loader.deleter
    def weight_loader(self):
        # Deletion is recorded as None so the getter can raise AttributeError.
        self._weight_loader = None  # type: ignore[assignment]

    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        """True when the param is a 1-element 1-D tensor and
        ``loaded_weight`` is a 0-d scalar (shapes differ but copy is valid)."""
        cond1 = self.data.ndim == 1 and self.data.numel() == 1
        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
        return (cond1 and cond2)

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into the parameter unchanged, after
        checking that the shapes are compatible."""
        assert (self.data.shape == loaded_weight.shape
                or self._is_1d_and_scalar(loaded_weight)), (
                    "Tried to load weight of shape "
                    f"{tuple(loaded_weight.shape)} into parameter of shape "
                    f"{tuple(self.data.shape)}")
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """
        Normalize a shard id to an integer index.

        Integers are returned unchanged; the strings "q", "k" and "v" map
        to 0, 1 and 2 respectively.
        """
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert isinstance(shard_id, str), (
            f"shard_id must be int or str, got {type(shard_id)}")
        assert shard_id in qkv_idxs, (
            f"unknown shard_id {shard_id!r}, expected one of "
            f"{list(qkv_idxs)}")
        return qkv_idxs[shard_id]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # On torch < 2.8 opt out of subclass dispatch entirely.
        if not is_torch_equal_or_newer("2.8.0"):
            logger.warning_once(
                "Torch %s detected (<2.8.0): returning NotImplemented in "
                "BasevLLMParameter.__torch_function__ to avoid potential "
                "TorchDynamo issues.",
                torch.__version__,
            )
            return NotImplemented

        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        """Dimension of ``data`` that is split across tensor-parallel ranks."""
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` along output_dim.

        The slice length equals the local parameter size along output_dim
        and starts at ``tp_rank * shard_size``.
        """
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             self.tp_rank * shard_size,
                                             shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Load one sub-matrix of a merged column-parallel layer.

        :param loaded_weight: checkpoint tensor for the sub-matrix; it is
            narrowed at ``tp_rank * shard_size`` along output_dim
        :param shard_offset: (kwarg) where the sub-matrix starts inside the
            local parameter along output_dim
        :param shard_size: (kwarg) per-rank size of the sub-matrix
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")

        # Packed parameters store several logical elements per storage
        # element, so offsets/sizes along the packed dim must be rescaled.
        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if isinstance(
                self,
            (PackedColumnParameter,
             PackedvLLMParameter)) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data

        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             self.tp_rank * shard_size,
                                             shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Load the q, k or v part of a fused QKV column-parallel layer.

        :param loaded_weight: checkpoint tensor for the selected part
        :param shard_offset: (kwarg) where the part starts inside the local
            parameter along output_dim
        :param shard_size: (kwarg) per-rank size of the part
        :param shard_id: (kwarg) "q", "k" or "v"
        :param num_heads: (kwarg) divisor applied to tp_rank for "k"/"v"
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if isinstance(
                self,
            (PackedColumnParameter,
             PackedvLLMParameter)) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        # "q" is sliced by tp_rank directly. For "k"/"v" the slice index is
        # tp_rank // num_heads, so consecutive ranks read the same slice.
        # NOTE(review): num_heads here is presumably the KV-head replication
        # factor supplied by the caller -- confirm.
        shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
                    num_heads)
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size, shard_size)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


209
class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        """Dimension of ``data`` that is split across tensor-parallel ranks."""
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` along input_dim.

        The slice length equals the local parameter size along input_dim
        and starts at ``tp_rank * shard_size``.
        """
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             self.tp_rank * shard_size,
                                             shard_size)

        # A 0-d result is promoted to shape (1,) to match 1-element params.
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


238
239
240
241
242
243
244
245
246
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights.

    Supports both column-parallel loading (via _ColumnvLLMParameter) and
    row-parallel loading (via RowvLLMParameter); it adds no behavior of its
    own.
    """


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """
    pass


class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
    """
    pass


class PerTensorScaleParameter(BasevLLMParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # For row parallel layers, no sharding needed
    # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        # Loaded as-is with no tensor-parallel slicing, exactly like the
        # row-parallel path.
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.

        :param loaded_weight: scale for one logical matrix; either 0-d or
            with a leading dimension of size 1
        :param shard_id: int index or "q"/"k"/"v"
        """

        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1, (
                "expected a leading dimension of size 1 for per-tensor "
                f"scale, got shape {tuple(loaded_weight.shape)}")
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


314
315
316
317
318
319
320
321
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Rescale (shard_size, shard_offset) for this parameter's packing
        and optional tile size; returns ``(shard_size, shard_offset)``."""
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
356
357


358
359
360
361
362
363
364
class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Rescale (shard_size, shard_offset) for this parameter's packing
        and optional tile size; returns ``(shard_size, shard_offset)``."""
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
404
405


406
407
408
409
410
411
412
413
414
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for the weight scales of block-wise quantized weights.

    Loading behavior comes entirely from its bases: column parallelism via
    _ColumnvLLMParameter and row parallelism via RowvLLMParameter.
    """


415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
class SharedWeightParameter(BasevLLMParameter):
    """
    Parameter for weights with many shared tensors across a model

    For example, when applying transforms to the "gate" and "up" partitions of
    `MergedColumnParallelLinear`, the transform weights must stay separate
    tensors in order to allow for tensor memory sharing between layers.

    Partition tensors are created with `add_partition` and stored in the
    `partitions` dict; reading `data` on this parameter itself raises.
    """
    # global registry for sharing tensors based on passed `data_key`
    # this dict holds weakrefs to avoid memory leak after model cleanup
    tensors_registry: WeakValueDictionary = WeakValueDictionary()

    # local container for strong references to shared tensors
    # this set compensates for the fact that torch.nn.Parameter
    # and Parameter subclasses do not hold reliable references to tensors
    local_tensors: set[torch.Tensor]

    # dictionary mapping partition indices to associated parameters
    partitions: dict[int, Union[ModelWeightParameter, Parameter]]

    def __new__(cls, **kwargs):
        # Created with data=None; the real tensors live in `partitions`.
        return super().__new__(cls, data=None, **kwargs)

    def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
        """
        :param input_dim: input dimension given to every partition's
            ModelWeightParameter
        :param output_dim: output dimension given to every partition's
            ModelWeightParameter
        :param weight_loader: (kwarg) weight loader for this parameter
        :raises NotImplementedError: if tensor parallel size is > 1
        """
        weight_loader: Callable = kwargs.get(
            "weight_loader")  # type: ignore[assignment]
        super().__init__(data=None, weight_loader=weight_loader)

        self.local_tensors = set()
        self.partitions = {}
        # Partitions get a loader that raises, so they can only be loaded
        # through this class's load_* methods.
        self.kwargs = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "weight_loader": self._fake_weight_loader
        }

        if self.tp_size > 1:
            raise NotImplementedError(f"{self.__class__.__name__} does not "
                                      "currently support tensor parallelism")

    def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
        """
        Add a partition to the weight parameter. Partitions whose `data_key`
        is the same will share tensor data

        :param index: index of partition to add
        :param data_key: hashable key used to key shared tensors
        :param *args: arguments for `torch.empty`
        :param **kwargs: keyword arguments for `torch.empty`
        """
        # load (shared) tensor using `data_key`; a single `get` avoids the
        # entry being collected between a membership test and the lookup
        data = self.tensors_registry.get(data_key)
        if data is None:
            data = torch.empty(*args, **kwargs)
            self.tensors_registry[data_key] = data

        # create associated model parameter
        self.partitions[index] = ModelWeightParameter(
            data=data, **self.kwargs)  # type: ignore[arg-type]

        # hold local reference, since ModelWeightParameter does not
        # see https://github.com/pytorch/pytorch/issues/75932
        self.local_tensors.add(data)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        # Only valid when there is exactly one partition, at index 0.
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_column_parallel_weight(
            partition, loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        # Only valid when there is exactly one partition, at index 0.
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        partition_id = kwargs.pop("shard_id")
        partition_id = self._shard_id_as_int(partition_id)
        partition = self.partitions[partition_id]

        # NOTE(review): the shard size is taken from the partition's
        # input_dim but applied along output_dim. This presumably assumes
        # square transform partitions -- confirm.
        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size

        ModelWeightParameter.load_merged_column_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
        partition = self.partitions[partition_id]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size
        shard_id = "q"  # fake first partition
        num_heads = kwargs.get("num_heads")

        ModelWeightParameter.load_qkv_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size,
            shard_id=shard_id,
            num_heads=num_heads,
        )

    def process_weights_after_loading(self):
        """Replace each partition with a plain ``torch.nn.Parameter``
        (requires_grad=False) wrapping the same data."""
        for key in self.partitions:
            self.partitions[key] = torch.nn.Parameter(
                data=self.partitions[key].data, requires_grad=False)

    @property
    def data(self):
        raise ValueError("Accessing `data` of a "
                         f"`{self.__class__.__name__}` is not allowed. "
                         "Instead, index `partitions` to get the weight of "
                         "the particular partition you want to access")

    def _fake_weight_loader(self, param: BasevLLMParameter,
                            loaded_weight: torch.Tensor,
                            loaded_weight_shard_id: Optional[Union[str, int]]):
        raise ValueError("When loading partition weights of "
                         f"{self.__class__.__name__}, use methods provided by "
                         f"{self.__class__.__name__}, not partition loader")


547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
def permute_param_layout_(param: "BasevLLMParameter", input_dim: int,
                          output_dim: int, **kwargs) -> "BasevLLMParameter":
    """
    Permute a parameter's layout to the specified input and output dimensions, 
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout 
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if 
    required, asserting if it cannot get it to the correct layout)

    :param param: parameter whose ``data`` is replaced by a permuted view
    :param input_dim: position the input dimension should end up at
    :param output_dim: position the output dimension should end up at
    :param packed_dim: (optional kwarg) position the packed dimension is
        expected to end up at; repacking is not supported
    :returns: ``param`` itself
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # the current input dim at input_dim and the current output dim at
    # output_dim, preserving the order of the other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in (curr_input_dim, curr_output_dim)
    ]
    # Insert in ascending order of target position: inserting at the lower
    # index last would shift the entry already placed at the higher index.
    for target_pos, src_dim in sorted(((input_dim, curr_input_dim),
                                       (output_dim, curr_output_dim))):
        perm.insert(target_pos, src_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


605
606
607
608
609
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    """Scale a shard's size and offset up by the marlin tile size.

    Returns ``(shard_size, shard_offset)``, each multiplied by
    ``marlin_tile_size``.
    """
    scaled_size = shard_size * marlin_tile_size
    scaled_offset = shard_offset * marlin_tile_size
    return scaled_size, scaled_offset


610
611
612
613
614
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
                                      bitblas_tile_size):
    """Scale a shard's size and offset down by the bitblas tile size.

    Returns ``(shard_size, shard_offset)``, each floor-divided by
    ``bitblas_tile_size``.
    """
    size_in_tiles, offset_in_tiles = (
        shard_size // bitblas_tile_size,
        shard_offset // bitblas_tile_size,
    )
    return size_in_tiles, offset_in_tiles


615
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size, bitblas_tile_size):
    """
    Convert a shard's size and offset into packed-storage units.

    Both values are floor-divided by ``packed_factor``. Then, if
    ``marlin_tile_size`` is set, they are scaled up by it. Otherwise, if
    ``bitblas_tile_size`` is set, they are scaled down by it. Marlin takes
    precedence when both are given.

    :returns: ``(shard_size, shard_offset)`` after adjustment
    """
    # packed_factor is presumably the number of logical elements stored in
    # one packed storage element (may be a Fraction) -- confirm with callers.
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    elif bitblas_tile_size is not None:
        return _adjust_shard_indexes_for_bitblas(
            shard_size=shard_size,
            shard_offset=shard_offset,
            bitblas_tile_size=bitblas_tile_size)

    return shard_size, shard_offset