# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/utils/sgmv.py
# License:  Apache License Version 2.0, January 2004

import os
import warnings
from functools import lru_cache
from typing import List, Tuple

import torch
import torch.nn.functional as F

try:
    import punica_kernels as _kernels

    HAS_SGMV = not bool(os.environ.get("DISABLE_SGMV", ""))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    _kernels = None
    HAS_SGMV = False


MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 64


def has_sgmv() -> bool:
    return HAS_SGMV


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    if not has_sgmv():
        return t

    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size

    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # see the complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)
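
# Worked example (illustrative only, not part of the original lorax code), assuming
# SGMV kernels are available (has_sgmv() is True) and world_size == 1:
#   - a tensor with size 6 along `dim` is padded up to MIN_SGMV_RANK == 8
#   - a tensor with size 20 along `dim` is padded up to 32, the next multiple of
#     SGMV_BLOCK_SIZE == 16
# For instance, pad_rank(torch.zeros(4, 20), dim=1, world_size=1).size(1) == 32.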


def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM


def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t


# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py
def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H1]`.
        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
            Weight matrix shape: `[num_layers, R, H2]`.
        s_start: Shape: `[S]`. DType: torch.int32. Start row index of each segment in `x`/`y`.
        s_end: Shape: `[S]`. DType: torch.int32. End row index (exclusive) of each segment in `x`/`y`.
        layer_idx: Layer index of the weight matrices.
        lora_rank: Rank of the LoRA adapters.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(
            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank
        )
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
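

# Reference sketch (not part of the original lorax code): the same segmented
# semantics as add_lora_sgmv_cutlass, written in plain PyTorch. It assumes the
# adapter weights are available as Python lists of tensors of shape
# [num_layers, R, H1] / [num_layers, R, H2] rather than as the raw device
# pointers (wa_ptr / wb_ptr) consumed by the punica kernels.
def _add_lora_sgmv_reference(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_list: List[torch.Tensor],
    wb_list: List[torch.Tensor],
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
):
    for i in range(len(wa_list)):
        lo, hi = int(s_start[i]), int(s_end[i])
        if hi <= lo:
            continue
        wa = wa_list[i][layer_idx]  # [R, H1]
        wb = wb_list[i][layer_idx]  # [R, H2]
        # y[lo:hi] += x[lo:hi] @ wa.T @ wb, matching the docstring above
        y[lo:hi] += x[lo:hi] @ wa.transpose(0, 1) @ wb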


def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)


@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)


@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)


def get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:
    return torch.empty((size,), dtype=torch.uint8, device=device)


def get_tmp_expand_size(size: int) -> int:
    return _kernels.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(
    nsegments: int, lora_rank: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank) and has_sgmv():
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size_no_kernels(nsegments, device)
        return tmp_shrink, tmp_expand


def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
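
# Typical composition (illustrative, hedged): the shrink and expand halves are
# meant to be used back to back, e.g.
#     v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
#     # ... the caller may rescale v here, e.g. by the LoRA alpha / r factor ...
#     lora_b_sgmv_cutlass(y, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)
# which together match the semantics of add_lora_sgmv_cutlass above.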


"""
Semantics:
    y[i] += (
        x[i].unsqueeze(0)
        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        * scale
    ).squeeze(0)

Args:
    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
    v: Shape: `[B, R]`. Temporary vector.
    x: Shape: `[B, H1]`. Input vectors.
    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.
    wb_T_all: Shape: `[None, L, H2, R]`. All of the transposed LoRA B matrices.
    indicies: Shape: `[B]`. Indices of the LoRA weights.
    layer_idx: Layer index of LoRA weights.
    scale: Scaling factor.
"""


def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
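

# Reference sketch (not part of the original lorax code): a plain-PyTorch loop
# implementing the BGMV semantics documented in the block string above, useful
# for checking the punica dispatch_bgmv kernels against eager execution. The
# `scale` argument is exposed here for completeness; the wrappers above
# hard-code it to 1.0.
def _bgmv_reference(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float = 1.0,
):
    for i in range(x.size(0)):
        wa_T = wa_T_all[indicies[i], layer_idx]  # [R, H1]
        wb_T = wb_T_all[indicies[i], layer_idx]  # [H2, R]
        y[i] += (
            x[i].unsqueeze(0)
            @ wa_T.transpose(-1, -2)
            @ wb_T.transpose(-1, -2)
            * scale
        ).squeeze(0)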


def segmented_matmul(
    y: torch.Tensor,
    x: torch.Tensor,
    w: List[torch.Tensor],
    b: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
):
    for i in range(len(w)):
        if s_end[i] - s_start[i] <= 0:
            continue

        xi = x[s_start[i] : s_end[i]]
        wi = w[i]
        bi = b[i]
        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)
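
# Example layout (illustrative): with s_start = [0, 3] and s_end = [3, 5],
# rows 0..2 of `x` are multiplied by w[0] (plus bias b[0]) and rows 3..4 by
# w[1] (plus bias b[1]); each w[i] has shape [out_features, in_features] as
# expected by F.linear.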