# Based on code from https://github.com/punica-ai/punica

from typing import NoReturn, Optional

import torch


def _raise_import_error(e: ImportError) -> NoReturn:
    """Re-raise a failed import of the punica kernels with a helpful message.

    This function always raises and never returns.

    Args:
      e: The ``ImportError`` raised while importing ``vllm._punica_C``. It is
        chained as the ``__cause__`` of the new error.

    Raises:
      ImportError: If a CUDA device is available and its compute capability
        is below 8.0, the message says the kernels require capability >= 8.0.
        Otherwise the message suggests rebuilding vLLM with
        ``VLLM_INSTALL_PUNICA_KERNELS=1``.
    """
    # Only query the device when CUDA is usable. get_device_capability()
    # itself raises on CPU-only builds or hosts without a GPU, which would
    # replace the ImportError that callers expect with an unrelated error.
    if (torch.cuda.is_available()
            and torch.cuda.get_device_capability() < (8, 0)):
        raise ImportError(
            "punica LoRA kernels require compute capability >= 8.0") from e
    raise ImportError(
        "punica LoRA kernels could not be imported. If you built vLLM "
        "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
        "was set.") from e


def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """Batched gather matrix-vector multiply, accumulated into ``y``.

    For every batch row ``i``, expressed with the parameter names below::

      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors, updated in place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
        matrices, stacked by weight index and layer.
      indicies: Shape: `[B]`. Per-row index selecting the weight matrix.
      layer_idx: Layer to select from ``w_t_all``.
      scale: Factor applied to each product before it is accumulated.

    A failure to import the compiled ``vllm._punica_C`` extension is
    reported through ``_raise_import_error``.
    """
    try:
        import vllm._punica_C as kernels
    except ImportError as import_err:
        _raise_import_error(import_err)
    else:
        kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """Apply a batched LoRA update to ``y`` in place.

    Semantics (expressed with the parameter names below)::

      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor, applied only in the second (B) stage.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer for the
        intermediate ``x @ A^T`` product. The first kernel call accumulates
        into it (``dispatch_bgmv`` follows the ``+=`` semantics documented on
        ``bgmv``), so a caller-supplied buffer presumably must be zero-filled
        -- confirm with callers. If omitted, a zeroed float32 buffer is
        allocated on ``x.device``.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    # Stage 1: buffer += x @ A^T (unscaled).
    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    # Stage 2: y += buffer @ B^T * scale.
    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                 scale)


def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """Like `add_lora`, but updates only a column slice of ``y``.

    Pass the whole ``y``. Only columns ``y_offset`` through
    ``y_offset + y_slice_size - 1`` receive the LoRA contribution.

    Semantics (expressed with the parameter names below)::

      y[i, y_offset:y_offset + y_slice_size] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H]`. Full output vectors, changed in place within the
        selected column slice.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices. H2 is presumably expected to equal
        ``y_slice_size`` -- confirm with callers.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor, applied only in the second (B) stage.
      y_offset: Offset of the first column of ``y`` to update.
      y_slice_size: Number of columns of ``y`` to update.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer for the
        intermediate ``x @ A^T`` product. If omitted, a zeroed float32
        buffer is allocated on ``x.device``.
    """
    try:
        import vllm._punica_C as kernels
    except ImportError as import_err:
        _raise_import_error(import_err)

    lora_rank = wb_t_all.size(-1)
    if buffer is None:
        # Keep the rank-R intermediate in float32 by default so it does not
        # lose accuracy to downcasting.
        buffer = torch.zeros((x.size(0), lora_rank),
                             dtype=torch.float32,
                             device=x.device)

    hidden_in = x.size(1)
    rank_width = buffer.size(1)

    # Stage 1: buffer += x @ A^T across all of buffer's columns (offset 0).
    kernels.dispatch_bgmv_low_level(buffer, x, wa_t_all, indicies, layer_idx,
                                    1.0, hidden_in, rank_width, 0)
    # Stage 2: write buffer @ B^T * scale into the requested column slice.
    kernels.dispatch_bgmv_low_level(y, buffer, wb_t_all, indicies, layer_idx,
                                    scale, rank_width, y_slice_size, y_offset)