parameter.py 23.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable, Hashable
5
from fractions import Fraction
6
from weakref import WeakValueDictionary
7
8
9

import torch
from torch.nn import Parameter
10
11
import vllm.envs as envs

12

13
14
15
16
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
17
18
19
from vllm.logger import init_logger

# Public parameter classes exported by this module.
__all__ = [
    "BasevLLMParameter",
    "PackedvLLMParameter",
    "PerTensorScaleParameter",
    "ModelWeightParameter",
    "ChannelQuantScaleParameter",
    "GroupQuantScaleParameter",
    "PackedColumnParameter",
    "RowvLLMParameter",
]

# Module-level logger named after this module.
# NOTE(review): not referenced by any code in this file as shown — confirm
# before removing.
logger = init_logger(__name__)


class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

40
    def __new__(cls, data: torch.Tensor | None, **kwargs):
41
42
43
44
45
46
47
48
49
50
51
52
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """

53
54
55
56
57
58
59
60
61
        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from vllm.platforms import current_platform
62

63
        if current_platform.use_sync_weight_loader():
64
            weight_loader = current_platform.make_synced_weight_loader(weight_loader)
65

66
        self._weight_loader = weight_loader
67
68
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
69
70

    @property
71
72
73
74
    def weight_loader(self) -> Callable:
        # NOTE(@ksayers) some models such as mamba_mixer2 override the
        # weight loader to support custom loading. In the future, model-specific
        # weight loading should be implemented via Model.load_weights. In the
75
        # meantime, support deleting and overriding `weight_loader` attribute
76
        if self._weight_loader is None:
77
78
79
            raise AttributeError(
                f"{self.__class__.__name__} weight_loader attribute has been deleted"
            )
80
81
        return self._weight_loader

82
83
84
85
86
87
88
89
    @weight_loader.setter
    def weight_loader(self, value: Callable):
        self._weight_loader = value

    @weight_loader.deleter
    def weight_loader(self):
        self._weight_loader = None  # type: ignore[assignment]

90
91
92
    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        cond1 = self.data.ndim == 1 and self.data.numel() == 1
        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
93
        return cond1 and cond2
94

95
    def _assert_and_load(self, loaded_weight: torch.Tensor):
96
97
98
        assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar(
            loaded_weight
        )
99
100
        self.data.copy_(loaded_weight)

101
    def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
102
103
        self._assert_and_load(loaded_weight)

104
    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False):
105
106
107
108
109
110
111
112
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

113
    def _shard_id_as_int(self, shard_id: str | int) -> int:
114
115
116
117
118
119
120
121
122
123
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert isinstance(shard_id, str)
        assert shard_id in qkv_idxs
        return qkv_idxs[shard_id]

124
125
126
127
128
129
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

130
131
132

class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.

    When ``envs.VLLM_USE_NN`` is enabled, non-quantized multi-dimensional
    parameters are stored transposed relative to the checkpoint tensor: the
    checkpoint is still sliced along ``output_dim``, but the matching slice of
    the parameter lives on the other axis, and the loaded slice is transposed
    before it is copied in.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def _column_shard_dim(
        self, param_data: torch.Tensor, is_quantization: bool | None
    ) -> int:
        """Return the axis of ``param_data`` holding this rank's output shard.

        This is ``output_dim``, except for non-quantized multi-dimensional
        parameters under VLLM_USE_NN, whose layout is transposed so the other
        axis is used (assumes a 2-D parameter in that case). A 1-D tensor has
        no transposed axis, so it always uses ``output_dim``.
        """
        if not envs.VLLM_USE_NN or param_data.ndim == 1 or is_quantization:
            return self.output_dim
        return int(not self.output_dim)

    @staticmethod
    def _column_to_param_layout(
        loaded_weight: torch.Tensor, is_quantization: bool | None
    ) -> torch.Tensor:
        """Transpose a checkpoint slice into the VLLM_USE_NN parameter layout.

        Applies only when VLLM_USE_NN is enabled and the weight is not
        quantized. ``Tensor.t()`` returns 0-D and 1-D tensors unchanged.
        """
        if envs.VLLM_USE_NN and not is_quantization:
            return loaded_weight.t()
        return loaded_weight

    def load_column_parallel_weight(
        self, loaded_weight: torch.Tensor, is_quantization: bool | None = False
    ):
        """Copy this rank's output slice of ``loaded_weight`` into the param.

        The checkpoint tensor is sliced along ``output_dim``. The slice length
        is read from the parameter's (possibly transposed) shard axis.
        """
        shard_size = self.data.shape[self._column_shard_dim(self.data, is_quantization)]
        loaded_weight = loaded_weight.narrow(
            self.output_dim, self.tp_rank * shard_size, shard_size
        )
        loaded_weight = self._column_to_param_layout(loaded_weight, is_quantization)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load one logical sub-matrix of a merged column-parallel weight.

        Expected kwargs: ``shard_offset`` and ``shard_size`` (location of the
        sub-matrix inside this parameter) and, optionally,
        ``is_quantization``.
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        is_quantization = kwargs.get("is_quantization")

        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if (
            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
            and self.packed_dim == self.output_dim
        ):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size
            )

        param_data = self.data
        # The shared helper keeps 1-D parameters on output_dim in NN mode, as
        # the sibling loaders do; a 1-D tensor has no transposed axis.
        param_data = param_data.narrow(
            self._column_shard_dim(param_data, is_quantization),
            shard_offset,
            shard_size,
        )
        loaded_weight = loaded_weight.narrow(
            self.output_dim, self.tp_rank * shard_size, shard_size
        )
        loaded_weight = self._column_to_param_layout(loaded_weight, is_quantization)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the q, k or v part of a fused QKV weight.

        Expected kwargs: ``shard_offset`` and ``shard_size`` (location inside
        this parameter), ``shard_id`` and ``num_heads`` (used to select this
        rank's slice of the checkpoint tensor) and, optionally,
        ``is_quantization``.
        """
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")
        is_quantization = kwargs.get("is_quantization")

        # TODO: move these to PackedColumnParameter and PackedvLLMParameter
        if (
            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
            and self.output_dim == self.packed_dim
        ):
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size
            )

        param_data = self.data
        # q is sliced by tp_rank; k/v by tp_rank // num_heads (presumably
        # because several ranks share one kv head — confirm with the caller).
        shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
        param_data = param_data.narrow(
            self._column_shard_dim(param_data, is_quantization),
            shard_offset,
            shard_size,
        )
        loaded_weight = loaded_weight.narrow(
            self.output_dim, shard_id * shard_size, shard_size
        )
        loaded_weight = self._column_to_param_layout(loaded_weight, is_quantization)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


227
class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.

    Under ``envs.VLLM_USE_NN`` (for non-quantized weights) multi-dimensional
    parameters are stored transposed relative to the checkpoint tensor.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(
        self, loaded_weight: torch.Tensor, is_quantization: bool | None = False
    ):
        """Copy this rank's input slice of ``loaded_weight`` into the param.

        The checkpoint tensor is sliced along ``input_dim``. A 0-D result is
        reshaped to 1-D, and under VLLM_USE_NN (non-quantized) the slice is
        transposed before the shape check and copy.
        """
        use_nn_layout = envs.VLLM_USE_NN and not is_quantization

        # In the transposed NN layout the input shard lives on the other axis
        # (assumes a 2-D parameter). 1-D parameters have no such axis and keep
        # using input_dim; indexing shape[1] on them would raise IndexError.
        if use_nn_layout and self.data.ndim != 1:
            shard_size = self.data.shape[int(not self.input_dim)]
        else:
            shard_size = self.data.shape[self.input_dim]

        loaded_weight = loaded_weight.narrow(
            self.input_dim, self.tp_rank * shard_size, shard_size
        )

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        if use_nn_layout:
            loaded_weight = loaded_weight.t()

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


262
263
264
265
266
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Weight parameter for linear layers. The column-parallel loaders come
    from _ColumnvLLMParameter and the row-parallel loader from
    RowvLLMParameter, so it works with either kind of parallelism.
    """


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Scale parameter for weights quantized group-wise. Loads with both
    column parallelism (via _ColumnvLLMParameter) and row parallelism
    (via RowvLLMParameter).
    """


class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Scale parameter for weights quantized channel-wise. Adds nothing to
    _ColumnvLLMParameter; it exists as a distinct, descriptive type.
    """


class PerTensorScaleParameter(BasevLLMParameter):
    """
    Scale parameter with one entry per logical matrix of a fused linear layer
    (a QKV layer, for instance, reads three scales from disk). Used for
    weights quantized per tensor; while loading, each incoming scale is
    written into the slot selected by its shard id.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load_row_parallel_weight(self, *args, **kwargs):
        # Row parallel layers need no sharding, so the scale is copied as is.
        super().load_row_parallel_weight(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        # Intentionally routed to the base *row* loader: the scale is copied
        # unsharded, exactly as in the row-parallel case.
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def _load_into_shard_id(
        self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs
    ):
        """
        Write ``loaded_weight`` into the element of this parameter selected
        by ``shard_id`` (an int, or one of "q"/"k"/"v").
        """
        target = self.data
        index = self._shard_id_as_int(shard_id)

        # AutoFP8 scales are 0-D; compressed-tensors scales carry a leading
        # axis of length 1, which is stripped here.
        if loaded_weight.dim() != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        target = target[index]
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)


342
343
344
345
346
347
348
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Column-parallel-only parameter whose data is packed on disk. The packing
    attributes (factor, packed dimension, optional Marlin/BitBLAS tile sizes)
    mean the same as on PackedvLLMParameter.
    """

    def __init__(
        self,
        packed_factor: int | Fraction,
        packed_dim: int,
        marlin_tile_size: int | None = None,
        bitblas_tile_size: int | None = None,
        **kwargs,
    ):
        # Packing metadata is kept in private fields and exposed read-only.
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        """Axis along which values are packed."""
        return self._packed_dim

    @property
    def packed_factor(self):
        """Number of logical values per stored element."""
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        """Marlin tile size, or None when not using Marlin."""
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        """BitBLAS tile size, or None when not using BitBLAS."""
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Convert a logical (size, offset) shard window to packed units;
        returns ``(shard_size, shard_offset)``."""
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self._packed_factor,
            marlin_tile_size=self._marlin_tile_size,
            bitblas_tile_size=self._bitblas_tile_size,
        )
387
388


389
390
391
392
393
394
class PackedvLLMParameter(ModelWeightParameter):
    """
    Model weight that is packed on disk, e.g. GPTQ Marlin int4/int8 values
    packed into int32. On top of ModelWeightParameter it records the packing
    factor, the packed axis and optional Marlin/BitBLAS tile sizes. These are
    used to rescale shard_size/shard_offset when loading fused linear layers.
    """

    def __init__(
        self,
        packed_factor: int | Fraction,
        packed_dim: int,
        marlin_tile_size: int | None = None,
        bitblas_tile_size: int | None = None,
        **kwargs,
    ):
        # Packing metadata is kept in private fields and exposed read-only.
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        """Axis along which values are packed."""
        return self._packed_dim

    @property
    def packed_factor(self):
        """Number of logical values per stored element."""
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        """Marlin tile size, or None when not using Marlin."""
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        """BitBLAS tile size, or None when not using BitBLAS."""
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Convert a logical (size, offset) shard window to packed units;
        returns ``(shard_size, shard_offset)``."""
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self._packed_factor,
            marlin_tile_size=self._marlin_tile_size,
            bitblas_tile_size=self._bitblas_tile_size,
        )
438
439


440
441
442
443
444
445
446
447
448
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Scale parameter for weights quantized block-wise. Loads with both
    column parallelism (via _ColumnvLLMParameter) and row parallelism
    (via RowvLLMParameter).
    """


449
450
451
452
453
454
455
456
class SharedWeightParameter(BasevLLMParameter):
    """
    Parameter for weights with many shared tensors across a model

    For example, when applying transforms to the "gate" and "up" partitions of
    `MergedColumnParallelLinear`, the transform weights must stay separate
    tensors in order to allow for tensor memory sharing between layers.

    The parameter itself holds no data (reading ``.data`` raises); weights
    live in per-index partitions in ``self.partitions``. Tensor parallelism
    is not supported (``tp_size > 1`` raises in ``__init__``).
    """

    # global registry for sharing tensors based on passed `data_key`
    # this dict holds weakrefs to avoid memory leak after model cleanup
    tensors_registry: WeakValueDictionary = WeakValueDictionary()

    # local container for strong references to shared tensors
    # this set compensates for the fact that torch.nn.Parameter
    # and Parameter subclasses do not hold reliable references to tensors
    local_tensors: set[torch.Tensor]

    # dictionary mapping partition indices to associated parameters
    partitions: dict[int, ModelWeightParameter | Parameter]

    def __new__(cls, **kwargs):
        return super().__new__(cls, data=None, **kwargs)

    def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
        weight_loader: Callable = kwargs.get("weight_loader")  # type: ignore[assignment]
        super().__init__(data=None, weight_loader=weight_loader)

        self.local_tensors = set()
        self.partitions = {}
        # Constructor kwargs for every partition's ModelWeightParameter. The
        # partitions get a loader that always raises, so they can only be
        # loaded through this class's load_* methods.
        self.kwargs = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "weight_loader": self._fake_weight_loader,
        }

        if self.tp_size > 1:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not "
                "currently support tensor parallelism"
            )

    def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
        """
        Add a partition to the weight parameter. Partitions whose `data_key`
        is the same will share tensor data

        :param index: index of partition to add
        :param data_key: hashable key used to key shared tensors
        :param *args: arguments for `torch.empty`
        :param **kwargs: keyword arguments for `torch.empty`
        """
        # load (shared) tensor using `data_key`
        if data_key not in self.tensors_registry:
            data = torch.empty(*args, **kwargs)
            self.tensors_registry[data_key] = data
        else:
            data = self.tensors_registry[data_key]

        # create associated model parameter
        self.partitions[index] = ModelWeightParameter(data=data, **self.kwargs)  # type: ignore[arg-type]

        # hold local reference, since ModelWeightParameter does not
        # see https://github.com/pytorch/pytorch/issues/75932
        self.local_tensors.add(data)

    def load_column_parallel_weight(
        self, loaded_weight: torch.Tensor, is_quantization: bool | None = False
    ):
        """Load into the only partition (index 0) as a column-parallel weight.

        ``is_quantization`` mirrors the base loader's signature and is
        forwarded unchanged.
        """
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_column_parallel_weight(
            partition, loaded_weight, is_quantization
        )

    def load_row_parallel_weight(
        self, loaded_weight: torch.Tensor, is_quantization: bool | None = False
    ):
        """Load into the only partition (index 0) as a row-parallel weight.

        ``is_quantization`` mirrors the base loader's signature and is
        forwarded unchanged.
        """
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_row_parallel_weight(
            partition, loaded_weight, is_quantization
        )

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the partition selected by ``kwargs["shard_id"]``.

        Since tp_size is 1 here (enforced in __init__), the computed shard is
        the full extent of the partition along ``input_dim`` at offset 0. An
        optional ``is_quantization`` kwarg is forwarded.
        """
        partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
        partition = self.partitions[partition_id]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size

        ModelWeightParameter.load_merged_column_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size,
            is_quantization=kwargs.get("is_quantization"),
        )

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """Load the partition selected by ``kwargs["shard_id"]`` ("q"/"k"/"v"
        or an int).

        Each partition is a separate tensor, so it is loaded as though it were
        the "q" shard. ``num_heads`` and an optional ``is_quantization`` kwarg
        are forwarded.
        """
        partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
        partition = self.partitions[partition_id]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size
        shard_id = "q"  # fake first partition
        num_heads = kwargs.get("num_heads")

        ModelWeightParameter.load_qkv_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size,
            shard_id=shard_id,
            num_heads=num_heads,
            is_quantization=kwargs.get("is_quantization"),
        )

    def process_weights_after_loading(self):
        """Replace each partition with a plain, gradient-free torch Parameter
        wrapping the same data."""
        for key in self.partitions:
            self.partitions[key] = torch.nn.Parameter(
                data=self.partitions[key].data, requires_grad=False
            )

    @property
    def data(self):
        # The shared parameter has no single tensor; callers must go through
        # the individual partitions instead.
        raise ValueError(
            f"Accessing `data` of a `{self.__class__.__name__}` is not allowed. "
            "Instead, use `partitions[index]` to get the weight of "
            "the particular partition you want to access"
        )

    def _fake_weight_loader(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_weight_shard_id: str | int | None,
    ):
        # Installed as every partition's weight_loader so that loading a
        # partition directly fails loudly.
        raise ValueError(
            "When loading partition weights of "
            f"{self.__class__.__name__}, use methods provided by "
            f"{self.__class__.__name__}, not partition loader"
        )
585
586


587
588
589
def permute_param_layout_(
    param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs
) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)

    ``param.data`` is replaced by the permuted tensor, and the private
    ``_input_dim`` / ``_output_dim`` / ``_packed_dim`` attributes (when
    present) are updated to the new positions. The same ``param`` is returned.
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2, (
            "permute_param_layout_ only supports 2D parameters when either "
            "input_dim or output_dim is not set"
        )

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None, "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None, "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
    ]
    # Insert in ascending order of target position so the second insertion
    # cannot shift the first. Inserting input first misplaces it for >2-D
    # tensors when output_dim < input_dim (e.g. 3-D with input_dim=1,
    # output_dim=0 would put the input axis at position 2). The sort is
    # stable, so for equal targets the original insertion order is kept.
    for target, source in sorted(
        ((input_dim, curr_input_dim), (output_dim, curr_output_dim)),
        key=lambda pair: pair[0],
    ):
        perm.insert(target, source)

    if "packed_dim" in kwargs:
        assert (
            hasattr(param, "packed_dim")
            and param.packed_dim == perm[kwargs["packed_dim"]]
        ), "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


645
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size):
646
647
648
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


649
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, bitblas_tile_size):
650
651
652
    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size


653
654
655
def _adjust_shard_indexes_for_packing(
    shard_size, shard_offset, packed_factor, marlin_tile_size, bitblas_tile_size
):
656
657
658
659
660
661
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
662
663
            marlin_tile_size=marlin_tile_size,
        )
664
665
666
667
    elif bitblas_tile_size is not None:
        return _adjust_shard_indexes_for_bitblas(
            shard_size=shard_size,
            shard_offset=shard_offset,
668
669
            bitblas_tile_size=bitblas_tile_size,
        )
670

671
    return shard_size, shard_offset