# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fractions import Fraction
5
6
7
8
9
10
11
from typing import Callable, Optional, Union

import torch
from torch.nn import Parameter

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
12
from vllm.model_executor.utils import _make_synced_weight_loader
13
from vllm.model_executor.layers.dp_attention import get_moe_tp_rank, get_moe_tp_size
14
15
16
17

# Public API of this module. Every public parameter class defined below is
# listed, together with the layout helper, so that
# ``from vllm.model_executor.parameter import *`` exposes all of them.
__all__ = [
    "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
    "ModelWeightParameter", "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter",
    "BlockQuantScaleParameter", "permute_param_layout_",
]

logger = init_logger(__name__)


class BasevLLMParameter(Parameter):
    """
    Common base class for the parameters of vLLM linear layers.

    A thin ``torch.nn.Parameter`` subclass that remembers the weight loader
    it was constructed with and exposes ``load_*`` entry points. At this
    level no sharding is performed: every ``load_*`` method copies the
    given tensor into the parameter unchanged.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        # Extra keyword arguments are accepted (and ignored) here so that
        # subclasses can take additional constructor arguments; the
        # underlying Parameter is always created with requires_grad=False.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Remember the weight loader for this parameter.

        :param data: torch tensor with the parameter data (already consumed
            by ``__new__``; not used again here)
        :param weight_loader: weight loader callable
        """
        # Weight loaders commonly take a view of the parameter, e.g.
        #   narrowed_tensor = param.data.narrow(0, offset, len)
        #   narrowed_tensor.copy_(real_weight)
        # and rely on the view sharing storage with param.data. On TPUs the
        # narrowed tensor propagates to the base tensor lazily, which causes
        # redundant memory usage and sometimes OOM errors during model
        # loading. Wrapping the loader so the parameter tensor is synced
        # after each call avoids this.
        from vllm.platforms import current_platform
        on_tpu = current_platform.is_tpu()
        self._weight_loader = (_make_synced_weight_loader(weight_loader)
                               if on_tpu else weight_loader)

    @property
    def weight_loader(self):
        """The weight loader stored at construction (wrapped on TPU)."""
        return self._weight_loader

    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        """
        Return True when this parameter is a 1-D tensor holding a single
        element and ``loaded_weight`` is a 0-dim tensor.
        """
        param_is_single = self.data.ndim == 1 and self.data.numel() == 1
        loaded_is_scalar = (loaded_weight.ndim == 0
                            and loaded_weight.numel() == 1)
        return param_is_single and loaded_is_scalar

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        """
        Copy ``loaded_weight`` into the parameter in place.

        The shapes must be equal, except that a 0-dim tensor may be loaded
        into a single-element 1-D parameter.
        """
        shapes_match = self.data.shape == loaded_weight.shape
        assert shapes_match or self._is_1d_and_scalar(loaded_weight)
        self.data.copy_(loaded_weight)

    # The four public loaders below behave identically in the base class;
    # subclasses override the ones that need sharding logic.

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into the parameter without sharding."""
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into the parameter without sharding."""
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Copy ``loaded_weight`` into the parameter; kwargs are ignored."""
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Copy ``loaded_weight`` into the parameter; kwargs are ignored."""
        self._assert_and_load(loaded_weight)


class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private base class providing the weight loading logic
    (load_column_parallel_weight, load_merged_column_weight,
    load_qkv_weight) for parameters of column-parallel linear layers,
    including QKV and MLP layers whose weights are not already fused on
    disk. An output dimension must be supplied. The methods are called
    from the weight loaders of the column parallel linear layers.

    Two instance attributes influence which rank's slice is read by the
    column and merged-column loaders:

    * ``expect_tp_size`` (default -1): when set to 1, rank 0 is used.
    * ``enable_dp_attn_moe`` (default False): when truthy, the rank is
      taken from ``get_moe_tp_rank()``; this takes precedence over
      ``expect_tp_size``.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)
        # Defaults; callers may overwrite these on the instance afterwards.
        self.expect_tp_size = -1
        self.enable_dp_attn_moe = False

    @property
    def output_dim(self):
        """Index of the dimension along which the output is sharded."""
        return self._output_dim

    def _column_tp_rank(self):
        """Rank whose slice the column / merged-column loaders copy."""
        rank = get_tensor_model_parallel_rank()
        if self.expect_tp_size == 1:
            rank = 0
        if self.enable_dp_attn_moe:
            rank = get_moe_tp_rank()
        return rank

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        """
        Copy this rank's slice of ``loaded_weight`` along ``output_dim``
        into the parameter. The slice length equals the parameter's own
        size along that dimension.
        """
        rank = self._column_tp_rank()
        dim = self.output_dim
        length = self.data.shape[dim]
        shard = loaded_weight.narrow(dim, rank * length, length)
        assert self.data.shape == shard.shape
        self.data.copy_(shard)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Copy this rank's slice of ``loaded_weight`` into the region of the
        parameter described by the ``shard_offset`` / ``shard_size`` kwargs
        (both measured along ``output_dim``).
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")

        # Packed parameters store several logical values per element, so
        # the indexes are rescaled when packing runs along the output dim.
        packed_types = (PackedColumnParameter, PackedvLLMParameter)
        if (isinstance(self, packed_types)
                and self.packed_dim == self.output_dim):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        rank = self._column_tp_rank()
        dim = self.output_dim

        target = self.data.narrow(dim, shard_offset, shard_size)
        source = loaded_weight.narrow(dim, rank * shard_size, shard_size)
        assert target.shape == source.shape
        target.copy_(source)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Copy a q/k/v slice of ``loaded_weight`` into the region of the
        parameter described by the ``shard_offset`` / ``shard_size`` kwargs.

        The source slice index is the tensor-parallel rank for
        ``shard_id == "q"`` and ``tp_rank // num_heads`` otherwise
        (``num_heads`` is read from kwargs). Unlike the other loaders,
        ``expect_tp_size`` and ``enable_dp_attn_moe`` are not consulted.
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        packed_types = (PackedColumnParameter, PackedvLLMParameter)
        if (isinstance(self, packed_types)
                and self.output_dim == self.packed_dim):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        rank = get_tensor_model_parallel_rank()
        if shard_id == "q":
            source_index = rank
        else:
            source_index = rank // num_heads

        dim = self.output_dim
        target = self.data.narrow(dim, shard_offset, shard_size)
        source = loaded_weight.narrow(dim, source_index * shard_size,
                                      shard_size)

        assert target.shape == source.shape
        target.copy_(source)


175
class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.

    Two instance attributes influence which rank's slice is read:

    * ``expect_tp_size`` (default -1): when set to 1, rank 0 is used.
    * ``enable_dp_attn_moe`` (default False): when truthy, the rank is
      taken from ``get_moe_tp_rank()``; this takes precedence over
      ``expect_tp_size``.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)
        # Defaults; callers may overwrite these on the instance afterwards.
        self.expect_tp_size = -1
        self.enable_dp_attn_moe = False

    @property
    def input_dim(self):
        """Index of the dimension along which the input is sharded."""
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        """
        Copy this rank's slice of ``loaded_weight`` along ``input_dim``
        into the parameter.

        A 0-dim ``loaded_weight`` is reshaped to shape ``(1,)`` and copied
        without slicing.

        :param loaded_weight: full (unsharded) tensor read from disk
        :raises AssertionError: if the resulting tensor's shape differs
            from the parameter's shape
        """
        if len(loaded_weight.shape) == 0:
            # torch.narrow() rejects 0-dim tensors, so the scalar case has
            # to be handled before any slicing. There is nothing to shard:
            # the single value is copied as a one-element tensor.
            loaded_weight = loaded_weight.reshape(1)
        else:
            tp_rank = get_tensor_model_parallel_rank()
            if self.expect_tp_size == 1:
                tp_rank = 0

            if self.enable_dp_attn_moe:
                tp_rank = get_moe_tp_rank()

            shard_size = self.data.shape[self.input_dim]
            loaded_weight = loaded_weight.narrow(self.input_dim,
                                                 tp_rank * shard_size,
                                                 shard_size)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


212
213
214
215
216
217
218
219
220
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Linear layer weight parameter.

    Combines the column-parallel and row-parallel loading behaviour of its
    two base classes and adds nothing of its own.
    """


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Weight-scale parameter for weights with grouped quantization.

    Combines the column-parallel and row-parallel loading behaviour of its
    two base classes and adds nothing of its own.
    """


class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Weight-scale parameter for weights with channel-wise quantization.

    Behaves exactly like _ColumnvLLMParameter; the subclass exists only to
    give these scales their own type.
    """


class PerTensorScaleParameter(BasevLLMParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        # Position of each q/k/v scale inside the fused parameter.
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """
        Translate a shard id into an index into the parameter.

        Integers are returned unchanged; otherwise the id must be one of
        the keys of ``qkv_idxs`` ("q", "k" or "v").
        """
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding needed
    # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    # Non-fused column parallel layers need no shard mapping either; the
    # weight is loaded as is through the base class's column loader.
    def load_column_parallel_weight(self, *args, **kwargs):
        super().load_column_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Copy ``loaded_weight`` into the entry of the parameter selected by
        ``shard_id``.

        :param loaded_weight: the scale; either 0-dim or with a leading
            dimension of size 1
        :param shard_id: integer index or one of "q", "k", "v"
        """

        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


299
300
301
302
303
304
305
306
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Column-parallel-only parameter whose values are packed on disk.

    Stores the packing factor, the packed dimension and optional marlin /
    bitblas tile sizes, and uses them to rescale shard indexes. See
    PackedvLLMParameter for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        # Packing metadata is stored before delegating the remaining
        # keyword arguments to the base classes.
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        self._bitblas_tile_size = bitblas_tile_size
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_factor(self):
        """Packing factor given at construction."""
        return self._packed_factor

    @property
    def packed_dim(self):
        """Index of the dimension along which values are packed."""
        return self._packed_dim

    @property
    def bitblas_tile_size(self):
        """Bitblas tile size, or None when not set."""
        return self._bitblas_tile_size

    @property
    def marlin_tile_size(self):
        """Marlin tile size, or None when not set."""
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """
        Rescale ``shard_size`` / ``shard_offset`` for this parameter's
        packing factor and tile sizes.

        :returns: tuple ``(shard_size, shard_offset)``
        """
        adjusted = _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
        return adjusted
341
342


343
344
345
346
347
348
349
class PackedvLLMParameter(ModelWeightParameter):
    """
    Model weight parameter whose values are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.

    Builds on ModelWeightParameter by additionally taking the packed
    factor, the packed dimension and, optionally, marlin / bitblas tile
    sizes. These are used to adjust shard_size and shard_offset when
    loading weights of fused linear layers.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        # Packing metadata is stored before delegating the remaining
        # keyword arguments to the base classes.
        self._packed_dim = packed_dim
        self._packed_factor = packed_factor
        self._bitblas_tile_size = bitblas_tile_size
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_factor(self):
        """Packing factor given at construction."""
        return self._packed_factor

    @property
    def packed_dim(self):
        """Index of the dimension along which values are packed."""
        return self._packed_dim

    @property
    def bitblas_tile_size(self):
        """Bitblas tile size, or None when not set."""
        return self._bitblas_tile_size

    @property
    def marlin_tile_size(self):
        """Marlin tile size, or None when not set."""
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """
        Rescale ``shard_size`` / ``shard_offset`` for this parameter's
        packing factor and tile sizes.

        :returns: tuple ``(shard_size, shard_offset)``
        """
        adjusted = _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)
        return adjusted
389
390


391
392
393
394
395
396
397
398
399
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Weight-scale parameter for weights with block-wise quantization.

    Combines the column-parallel and row-parallel loading behaviour of its
    two base classes and adds nothing of its own.
    """


400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def permute_param_layout_(param: "BasevLLMParameter", input_dim: int,
                          output_dim: int, **kwargs) -> "BasevLLMParameter":
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)

    :param param: parameter to permute in place; its ``input_dim`` /
        ``output_dim`` attributes (when present) give the current layout
    :param input_dim: index the input dimension should end up at
    :param output_dim: index the output dimension should end up at
    :param kwargs: optional ``packed_dim``, the index the packed dimension
        is expected to be at after the permutation
    :returns: the same ``param`` object, with ``data`` permuted and its
        ``_input_dim`` / ``_output_dim`` / ``_packed_dim`` attributes (when
        present) updated
    :raises AssertionError: if the current layout cannot be determined or
        the requested ``packed_dim`` would require repacking
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    # Insert in ascending order of destination index. Inserting the larger
    # destination first would let the second insert shift it one slot to
    # the right whenever other dimensions are present (rank > 2).
    placements = sorted([(input_dim, curr_input_dim),
                         (output_dim, curr_output_dim)])
    for destination, source in placements:
        perm.insert(destination, source)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


458
459
460
461
462
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


463
464
465
466
467
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
                                      bitblas_tile_size):
    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size


468
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
469
                                      marlin_tile_size, bitblas_tile_size):
470
471
472
473
474
475
476
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
477
478
479
480
481
482
483
    elif bitblas_tile_size is not None:
        return _adjust_shard_indexes_for_bitblas(
            shard_size=shard_size,
            shard_offset=shard_offset,
            bitblas_tile_size=bitblas_tile_size)

    return shard_size, shard_offset