# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import wraps

import torch
import torch.nn as nn


@dataclass
class LoRAMapping:
    """Bundle of LoRA index mappings plus a prefill flag.

    Both mapping fields are coerced to ``tuple`` after initialisation, so any
    iterable of ints may be passed to the constructor.
    """

    # Sequence of ints; stored as a tuple after ``__post_init__``.
    # NOTE(review): presumably one LoRA index per token -- confirm with callers.
    index_mapping: tuple[int, ...]
    # Sequence of ints; stored as a tuple after ``__post_init__``.
    # NOTE(review): presumably one LoRA index per prompt -- confirm with callers.
    prompt_mapping: tuple[int, ...]
    # Defaults to False when not supplied.
    is_prefill: bool = False

    def __post_init__(self):
        """Normalise both mappings to tuples, whatever iterable was given."""
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # unquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
    # GPTQ/AWQ
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # HQQ marlin
    elif hasattr(base_layer, "W_q"):
        return base_layer.W_q.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
48
        condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
Jee Jee Li's avatar
Jee Jee Li committed
49
50
51
52
53
54
55
56
57
58
59
60
        return can_replace(*args, **kwargs) and condition

    return dec


def _fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
61
62
63
        return (
            can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
        )
Jee Jee Li's avatar
Jee Jee Li committed
64
65

    return dec