Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7ecee343
Unverified
Commit
7ecee343
authored
Aug 01, 2024
by
Jee Jee Li
Committed by
GitHub
Jul 31, 2024
Browse files
[Kernel][RFC] Refactor the punica kernel based on Triton (#5036)
parent
7eb0cb4a
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1196 additions
and
191 deletions
+1196
-191
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+205
-0
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+189
-0
vllm/lora/ops/utils.py
vllm/lora/ops/utils.py
+46
-0
vllm/lora/punica.py
vllm/lora/punica.py
+581
-184
vllm/triton_utils/__init__.py
vllm/triton_utils/__init__.py
+2
-1
vllm/triton_utils/libentry.py
vllm/triton_utils/libentry.py
+167
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+6
-6
No files found.
vllm/lora/ops/sgmv_expand_slice.py
0 → 100644
View file @
7ecee343
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
import
triton
import
triton.language
as
tl
from
vllm.triton_utils
import
libentry
@
libentry
()
@
triton
.
jit
def
_sgmv_expand_slice_kernel
(
input_ptr
,
lora_ptr
,
out_ptr
,
N
,
K
,
b_seq_start_loc
,
seq_lens
,
lora_indices
,
xm_stride
,
xk_stride
,
# 1
l0_stride
,
# hidden_size*max_rank
lora_k_stride
,
lora_n_stride
,
cm_stride
,
cn_stride
,
slice_offset
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
ADD_INPUTS
:
tl
.
constexpr
,
CAST_TYPE
:
tl
.
constexpr
,
):
"""
Similar to the 'sgmv_expand' operator, but with an added parameter
'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
might be that in the future, we could implement a fusion operator to
achieve the current functionality instead of having to call it multiple
times.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
cur_batch
=
tl
.
program_id
(
axis
=
1
)
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
pid_m
=
pid
//
cta_n_num
pid_n
=
pid
%
cta_n_num
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
if
pid_m
*
BLOCK_M
>
M
:
return
lora_index
=
tl
.
load
(
lora_indices
+
cur_batch
)
if
lora_index
==
-
1
:
return
cur_seq_start
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_k
=
tl
.
arange
(
0
,
BLOCK_K
)
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
N
,
BLOCK_N
),
BLOCK_N
)
a_ptr
=
(
input_ptr
+
cur_seq_start
*
xm_stride
+
ram
[:,
None
]
*
xm_stride
+
offset_k
[
None
,
:]
*
xk_stride
,
)
b_ptr
=
(
lora_ptr
+
l0_stride
*
lora_index
+
offset_k
[:,
None
]
*
lora_n_stride
+
rbn
[
None
,
:]
*
lora_k_stride
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
tl
.
cdiv
(
K
,
BLOCK_K
)):
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
else
:
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
K
-
k
*
BLOCK_K
,
other
=
0
)
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
K
-
k
*
BLOCK_K
,
other
=
0
)
if
CAST_TYPE
:
tiled_a
=
tiled_a
.
to
(
lora_ptr
.
dtype
.
element_ty
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
,
)
a_ptr
+=
BLOCK_K
*
xk_stride
b_ptr
+=
BLOCK_K
*
lora_n_stride
tiled_c
=
accumulator
.
to
(
lora_ptr
.
dtype
.
element_ty
)
offset_cm
=
cur_seq_start
+
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
+
slice_offset
c_ptr
=
(
out_ptr
+
offset_cm
[:,
None
]
*
cm_stride
+
offset_cn
[
None
,
:]
*
cn_stride
)
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
(
slice_offset
+
N
))
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
tiled_c
+=
tiled_out
tl
.
store
(
c_ptr
,
tiled_c
,
mask
=
c_mask
)
@
torch
.
inference_mode
()
def
sgmv_expand_slice
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
b_seq_start_loc
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
):
"""_summary_
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output..
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
lora_b_weights
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
1
)
==
lora_b_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
slice_size
==
lora_b_weights
.
size
(
-
2
)
assert
inputs
.
is_contiguous
()
assert
output_tensor
.
is_contiguous
()
if
lora_b_weights
.
ndim
==
4
:
# shape:(lora_num,1,size,rank)
assert
lora_b_weights
.
size
(
1
)
==
1
lora_b_weights
=
lora_b_weights
.
squeeze
(
dim
=
1
)
else
:
assert
lora_b_weights
.
ndim
==
3
# shape:(lora_num,size,rank)
assert
lora_b_weights
.
is_contiguous
()
# TODO tuning this config
N
,
K
=
lora_b_weights
.
shape
[
-
2
:]
# K= rank,N=hidden_size
BLOCK_M
=
32
BLOCK_N
=
32
BLOCK_K
=
16
EVEN_K
=
K
%
BLOCK_K
==
0
ADD_INPUTS
=
add_inputs
CAST_TYPE
=
False
if
inputs
.
dtype
==
torch
.
float32
and
lora_b_weights
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
]:
CAST_TYPE
=
True
grid
=
(
triton
.
cdiv
(
max_seq_length
,
BLOCK_M
)
*
triton
.
cdiv
(
N
,
BLOCK_N
),
batches
,
)
_sgmv_expand_slice_kernel
[
grid
](
inputs
,
lora_b_weights
,
output_tensor
,
N
,
K
,
b_seq_start_loc
,
seq_len_tensor
,
lora_indices_tensor
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
lora_b_weights
.
stride
(
0
),
lora_b_weights
.
stride
(
1
),
lora_b_weights
.
stride
(
2
),
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
slice_offset
,
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
ADD_INPUTS
,
CAST_TYPE
,
)
return
vllm/lora/ops/sgmv_shrink.py
0 → 100644
View file @
7ecee343
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import
torch
import
triton
import
triton.language
as
tl
from
vllm.triton_utils
import
libentry
@
libentry
()
@
triton
.
jit
def
_sgmv_shrink_kernel
(
input_ptr
,
lora_ptr
,
out_ptr
,
N
,
K
,
b_seq_start_loc
,
seq_lens
,
lora_indices
,
scaling
,
xm_stride
,
# hidden_size
xk_stride
,
# 1
l0_stride
,
# hidden_size*max_rank
lora_k_stride
,
lora_n_stride
,
cm_stride
,
cn_stride
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
):
"""
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
introducing SPLIT-K can improve performance
"""
pid
=
tl
.
program_id
(
axis
=
0
)
pid_sk
=
tl
.
program_id
(
axis
=
1
)
cur_batch
=
tl
.
program_id
(
axis
=
2
)
cta_n_num
=
tl
.
cdiv
(
N
,
BLOCK_N
)
pid_m
=
pid
//
cta_n_num
pid_n
=
pid
%
cta_n_num
M
=
tl
.
load
(
seq_lens
+
cur_batch
)
if
pid_m
*
BLOCK_M
>
M
:
return
lora_index
=
tl
.
load
(
lora_indices
+
cur_batch
)
if
lora_index
==
-
1
:
return
cur_seq_start
=
tl
.
load
(
b_seq_start_loc
+
cur_batch
)
offset_m
=
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_n
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
offset_k
=
pid_sk
*
BLOCK_K
+
tl
.
arange
(
0
,
BLOCK_K
)
ram
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_m
%
M
,
BLOCK_M
),
BLOCK_M
)
rbn
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offset_n
%
N
,
BLOCK_N
),
BLOCK_N
)
a_ptr
=
(
input_ptr
+
cur_seq_start
*
xm_stride
+
ram
[:,
None
]
*
xm_stride
+
offset_k
[
None
,
:]
*
xk_stride
)
b_ptr
=
(
lora_ptr
+
l0_stride
*
lora_index
+
rbn
[
None
,
:]
*
lora_k_stride
+
offset_k
[:,
None
]
*
lora_n_stride
)
accumulator
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)):
if
EVEN_K
:
tiled_a
=
tl
.
load
(
a_ptr
)
tiled_b
=
tl
.
load
(
b_ptr
)
else
:
k_remaining
=
K
-
k
*
(
BLOCK_K
*
SPLIT_K
)
tiled_a
=
tl
.
load
(
a_ptr
,
mask
=
offset_k
[
None
,
:]
<
k_remaining
,
other
=
0.0
)
tiled_b
=
tl
.
load
(
b_ptr
,
mask
=
offset_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
accumulator
+=
tl
.
dot
(
tiled_a
,
tiled_b
)
a_ptr
+=
BLOCK_K
*
SPLIT_K
*
xk_stride
b_ptr
+=
BLOCK_K
*
SPLIT_K
*
lora_n_stride
offset_cm
=
cur_seq_start
+
tl
.
arange
(
0
,
BLOCK_M
)
+
pid_m
*
BLOCK_M
offset_cn
=
tl
.
arange
(
0
,
BLOCK_N
)
+
pid_n
*
BLOCK_N
c_ptr
=
(
out_ptr
+
offset_cm
[:,
None
]
*
cm_stride
+
offset_cn
[
None
,
:]
*
cn_stride
)
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
N
)
accumulator
*=
scaling
# handles write-back with reduction-splitting
if
SPLIT_K
==
1
:
tl
.
store
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
else
:
tl
.
atomic_add
(
c_ptr
,
accumulator
,
mask
=
c_mask
)
@
torch
.
inference_mode
()
def
sgmv_shrink
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
b_seq_start_loc
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
scaling
:
float
,
):
"""
Args:
inputs (torch.Tensor): input tensor
lora_a_weights (torch.Tensor): lora'a weight
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
assert
lora_a_weights
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
]
assert
inputs
.
size
(
1
)
==
lora_a_weights
.
size
(
-
1
)
assert
b_seq_start_loc
.
size
(
0
)
==
batches
assert
lora_indices_tensor
.
size
(
0
)
==
batches
assert
inputs
.
is_contiguous
()
if
lora_a_weights
.
ndim
==
4
:
# shape:(lora_num,1,rank, size)
assert
lora_a_weights
.
size
(
1
)
==
1
lora_a_weights
=
lora_a_weights
.
squeeze
(
dim
=
1
)
else
:
assert
lora_a_weights
.
ndim
==
3
# shape:(lora_num,rank, size)
assert
lora_a_weights
.
is_contiguous
()
assert
output_tensor
.
is_contiguous
()
# TODO tuning this config
N
,
K
=
lora_a_weights
.
shape
[
-
2
:]
# K=hidden_size,N=rank
BLOCK_M
=
32
BLOCK_N
=
16
BLOCK_K
=
32
SPLIT_K
=
8
EVEN_K
=
K
%
(
BLOCK_K
*
SPLIT_K
)
==
0
grid
=
(
triton
.
cdiv
(
max_seq_length
,
BLOCK_M
)
*
triton
.
cdiv
(
N
,
BLOCK_N
),
SPLIT_K
,
batches
,
)
_sgmv_shrink_kernel
[
grid
](
inputs
,
lora_a_weights
,
output_tensor
,
N
,
K
,
b_seq_start_loc
,
seq_len_tensor
,
lora_indices_tensor
,
scaling
,
inputs
.
stride
(
0
),
inputs
.
stride
(
1
),
lora_a_weights
.
stride
(
0
),
lora_a_weights
.
stride
(
1
),
lora_a_weights
.
stride
(
2
),
output_tensor
.
stride
(
0
),
output_tensor
.
stride
(
1
),
BLOCK_M
,
BLOCK_N
,
BLOCK_K
,
EVEN_K
,
SPLIT_K
,
)
return
vllm/lora/ops/utils.py
0 → 100644
View file @
7ecee343
import
functools
from
typing
import
Dict
@
functools
.
lru_cache
def
_get_op_configs
(
op_type
:
str
,
batch
:
int
,
hidden_size
:
int
):
# TODO: add optimal configurations
return
None
def
_check_divisibility
(
hidden_size
:
int
):
# The bgmv_expand kernel requires that the hidden_size be divisible by
# the number below.
divisibility
=
[
2
,
4
,
8
,
16
,
32
,
64
]
divisibility
.
sort
(
reverse
=
True
)
for
div
in
divisibility
:
if
hidden_size
%
div
==
0
:
return
div
# hidden_size is an odd number
return
1
def
_get_default_config
(
op_type
:
str
,
batch
:
int
,
hidden_size
:
int
):
if
op_type
==
"expand"
:
return
{
"BLOCK_N"
:
256
,
"SPLIT_N"
:
_check_divisibility
(
hidden_size
),
"num_warps"
:
8
}
else
:
return
{
"BLOCK_K"
:
256
,
"SPLIT_K"
:
64
,
"num_warps"
:
8
}
def
get_lora_op_configs
(
op_type
:
str
,
batch
:
int
,
hidden_size
:
int
)
->
Dict
[
str
,
int
]:
"""Inspired by `fused_moe_kernel`
The return value will be a dictionary mapping an irregular grid of batch
sizes and hidden_size to configurations of the bgmv-related kernel.
NOTE: It currently only supports the default configuration. We plan to
generate optimal configurations for different hardware in the future using
scripts similar to `benchmark_moe.py`.
"""
config
=
_get_op_configs
(
op_type
,
batch
,
hidden_size
)
if
not
config
:
config
=
_get_default_config
(
op_type
,
batch
,
hidden_size
)
return
config
vllm/lora/punica.py
View file @
7ecee343
# Based on code from https://github.com/punica-ai/punica
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Optional
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.lora.ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
def
_check_punica_support
():
if
ops
.
is_custom_op_supported
(
"_punica_C::dispatch_bgmv"
):
return
if
TYPE_CHECKING
:
# avoid circuit import
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.models
import
LongContextLoRAContext
if
current_platform
.
get_device_capability
()
<
(
8
,
0
):
raise
ImportError
(
"punica LoRA kernels require compute capability >= 8.0"
)
else
:
raise
ImportError
(
"punica LoRA kernels could not be imported. If you built vLLM "
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
"was set."
)
def
bgmv
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
scale
:
float
,
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
matrices.
indicies: Shape: `[B]`. Indices of the weight matrices.
layer_idx: Layer index of the weight matrices.
scale: Scaling factor.
def
compute_meta
(
token_lora_tensor
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
,
bool
]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
_check_punica_support
()
ops
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
lora_indices_tensor
,
seq_length_tensor
=
torch
.
unique_consecutive
(
token_lora_tensor
,
return_counts
=
True
)
cum_result
=
torch
.
cumsum
(
seq_length_tensor
,
dim
=
0
)
b_seq_start_tensor
=
torch
.
zeros_like
(
seq_length_tensor
)
b_seq_start_tensor
[
1
:].
copy_
(
cum_result
[:
-
1
])
max_length
=
seq_length_tensor
.
max
().
item
()
batch_size
=
lora_indices_tensor
.
size
(
0
)
no_lora
=
False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if
batch_size
==
1
and
lora_indices_tensor
==
-
1
:
no_lora
=
True
return
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
def
dispatch_bgmv_low_level
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
scale
:
float
,
y_offset
:
int
,
y_slice_size
:
int
):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
# TODO see if this can be vectorized
def
convert_mapping
(
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
List
[
int
]]:
"""Converts LoRAMapping to index tensors.
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
long_lora_context: Passed if there are long context lora in a batch.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indicies. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indicies, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
long_lora_indices: Tensor of shape [batch_size] mapping
requests to RoPE offsets and rot dims for long LoRAs.
None if long context lora doesn't exist.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, long_lora_indices).
"""
_check_punica_support
()
ops
.
dispatch_bgmv_low_level
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
,
x
.
size
(
1
),
y_slice_size
,
y_offset
,
)
index_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_indices
.
copy
()
lora_indices
=
index_mapping_indices
.
copy
()
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
if
long_lora_context
:
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
]
lora_idx
=
None
for
i
in
range
(
len
(
index_mapping_indices
)):
# TODO index can be slow. optimize
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_indices
[
i
])
if
index_mapping_indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_indices
[
i
]
>
0
else
0
lora_indices
[
i
]
=
lora_idx
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
lora_offset
:
int
=
long_lora_context
.
offsets_by_lora_id
.
get
(
index_mapping_indices
[
i
],
0
)
long_lora_offsets
[
i
]
=
lora_offset
indices_list
:
List
[
Union
[
List
[
int
],
torch
.
Tensor
]]
=
[
index_mapping_indices
,
lora_indices
,
embedding_indices
,
]
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
indices_list
.
append
(
long_lora_offsets
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
),
])
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
base_indices
=
indices
[
1
]
sampler_indices
=
prompt_mapping_tensor
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
))
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
if
long_lora_context
:
long_lora_indices
=
indices
[
3
]
long_lora_indices_len
=
long_lora_indices
.
shape
[
-
1
]
# Contain length of indices tensors. Used to index into each tensor.
indices_len
=
[
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
],
]
if
long_lora_indices_len
is
not
None
:
indices_len
.
append
(
long_lora_indices_len
)
else
:
# If long_lora doesn't exist,append None
indices_len
.
append
(
None
)
def
add_lora
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
scale
:
float
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_indices
,
indices_len
,
)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
class
PunicaWrapper
:
"""
_check_punica_support
()
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
ops
.
dispatch_bgmv
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
1.0
)
ops
.
dispatch_bgmv
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
scale
)
def
add_lora_slice
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
scale
:
float
,
y_offset
:
int
,
y_slice_size
:
int
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
):
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
"""
Same as `add_lora` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
device
:
str
):
self
.
_token_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_sampler_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_sampler_indices_padded
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_embeddings_indices
=
torch
.
empty
(
2
,
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_long_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
device
)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
LoRA A matrices.
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
LoRA B matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
_check_punica_support
()
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default to avoid
# numerical inaccuracies that would otherwise happen
# due to downcasting.
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
ops
.
dispatch_bgmv_low_level
(
buffer
,
x
,
wa_t_all
,
indicies
,
layer_idx
,
1.0
,
x
.
size
(
1
),
buffer
.
size
(
1
),
0
,
)
ops
.
dispatch_bgmv_low_level
(
y
,
buffer
,
wb_t_all
,
indicies
,
layer_idx
,
scale
,
buffer
.
size
(
1
),
y_slice_size
,
y_offset
,
)
# 5 is the number of indicies tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
5
# these attributes are the information required for sgmv kernel
self
.
_seq_start_locs
=
torch
.
empty
(
max_batches
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_seq_lengths
=
torch
.
empty
(
max_batches
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
_lora_indices_per_batch
=
torch
.
empty
(
max_batches
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
max_length
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
is_prefill
=
False
self
.
no_lora
=
False
def
update_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
self
.
_update_base_metadata
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
long_lora_context
)
if
mapping
.
is_prefill
:
# Update metadata required for prefill-related operators.
self
.
_update_prefill_metada
(
self
.
token_lora_indices
)
self
.
is_prefill
=
True
else
:
self
.
is_prefill
=
False
def
_update_base_metadata
(
self
,
mapping
:
"LoRAMapping"
,
lora_index_to_id
:
List
[
Optional
[
int
]],
max_loras
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
):
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
long_lora_offsets_tensor
,
indices_len
,
)
=
convert_mapping
(
mapping
,
lora_index_to_id
,
max_loras
,
vocab_size
,
extra_vocab_size
,
long_lora_context
,
)
self
.
_token_lora_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
_sampler_indices
[:
sampler_indices
.
shape
[
0
]].
copy_
(
sampler_indices
)
self
.
_sampler_indices_padded
[:
sampler_indices_padded
.
shape
[
0
]].
copy_
(
sampler_indices_padded
)
self
.
_embeddings_indices
[:
embeddings_indices
.
shape
[
0
],
:
embeddings_indices
.
shape
[
1
]].
copy_
(
embeddings_indices
)
if
long_lora_offsets_tensor
is
not
None
:
self
.
_long_lora_indices
[:
long_lora_offsets_tensor
.
shape
[
0
]].
copy_
(
long_lora_offsets_tensor
)
else
:
self
.
_long_lora_indices
.
zero_
()
self
.
indices_len
[:]
=
indices_len
def
_update_prefill_metada
(
self
,
token_lora_tensor
:
torch
.
Tensor
)
->
None
:
(
b_seq_start_tensor
,
seq_length_tensor
,
lora_indices_tensor
,
batch_size
,
max_length
,
no_lora
)
=
compute_meta
(
token_lora_tensor
)
self
.
_seq_start_locs
[:
b_seq_start_tensor
.
shape
[
0
]].
copy_
(
b_seq_start_tensor
)
self
.
_seq_lengths
[:
seq_length_tensor
.
shape
[
0
]].
copy_
(
seq_length_tensor
)
self
.
_lora_indices_per_batch
[:
lora_indices_tensor
.
shape
[
0
]].
copy_
(
lora_indices_tensor
)
self
.
batch_size
=
batch_size
self
.
max_length
=
max_length
self
.
no_lora
=
no_lora
@
property
def
prefill_metadata
(
self
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
int
,
int
]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
"""
return
(
self
.
_seq_start_locs
[:
self
.
batch_size
],
self
.
_seq_lengths
[:
self
.
batch_size
],
self
.
_lora_indices_per_batch
[:
self
.
batch_size
],
self
.
batch_size
,
self
.
max_length
)
@
property
def
token_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides the lora indices corresponding to each token
in the batch. An index of -1 means no lora should be applied.
"""
token_lora_len
=
self
.
indices_len
[
0
]
return
self
.
_token_lora_indices
[:
token_lora_len
]
@
property
def
sampler_indices
(
self
)
->
torch
.
Tensor
:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
"""
sampler_indices_len
=
self
.
indices_len
[
1
]
return
self
.
_sampler_indices
[:
sampler_indices_len
]
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices
"""
indices_padded_len
=
self
.
indices_len
[
2
]
return
self
.
_sampler_indices_padded
[:
indices_padded_len
]
@
property
def
embeddings_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
"""
embeddings_indices_len
=
self
.
indices_len
[
3
]
return
self
.
_embeddings_indices
[:,
:
embeddings_indices_len
]
@
property
def
long_lora_indices
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
"""
long_lora_len
=
self
.
indices_len
[
4
]
return
self
.
_long_lora_indices
[:
long_lora_len
]
def
shrink_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_shrink
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
scale
,
)
def
shrink_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
bgmv_shrink
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
scale
)
def
expand_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
add_input
,
)
def
expand_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
,
):
bgmv_expand
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
add_input
)
def
expand_slice_prefill
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
#No LoRA request, so return directly
if
self
.
no_lora
:
return
sgmv_expand_slice
(
x
,
w_t_all
,
y
,
*
self
.
prefill_metadata
,
y_offset
,
y_slice_size
,
add_input
,
)
def
expand_slice_decode
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
,
):
bgmv_expand_slice
(
x
,
w_t_all
,
y
,
self
.
token_lora_indices
,
y_offset
,
y_slice_size
,
add_input
)
def
add_shrink
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
scale
:
float
,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the shrink_decode function
should be called.
"""
shrink_fun
:
Callable
=
(
self
.
shrink_prefill
if
self
.
is_prefill
else
self
.
shrink_decode
)
shrink_fun
(
y
,
x
,
w_t_all
,
scale
)
def
add_expand
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
add_input
:
bool
=
True
,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
should be called.
"""
expand_fun
:
Callable
=
(
self
.
expand_prefill
if
self
.
is_prefill
else
self
.
expand_decode
)
expand_fun
(
y
,
x
,
w_t_all
,
add_input
)
def
add_expand_slice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
y_offset
:
Optional
[
int
],
y_slice_size
:
Optional
[
int
],
add_input
:
bool
=
True
):
"""
Similar to `add_expand`
"""
expand_slice_fun
:
Callable
=
(
self
.
expand_slice_prefill
if
self
.
is_prefill
else
self
.
expand_slice_decode
)
expand_slice_fun
(
y
,
x
,
w_t_all
,
y_offset
,
y_slice_size
,
add_input
)
def
add_lora
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
scale
:
float
,
y_offset
:
Optional
[
int
]
=
None
,
y_slice_size
:
Optional
[
int
]
=
None
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
"""
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
wa_t_all (torch.Tensor): lora_a's weight
wb_t_all (torch.Tensor): lora_b's weight
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice..
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
self
.
add_shrink
(
buffer
,
x
,
wa_t_all
,
scale
)
if
y_offset
is
None
and
y_slice_size
is
None
:
self
.
add_expand
(
y
,
buffer
,
wb_t_all
,
add_input
=
True
)
else
:
self
.
add_expand_slice
(
y
,
buffer
,
wb_t_all
,
y_offset
,
y_slice_size
,
add_input
=
True
)
y
=
y
.
view_as
(
y_org
)
def
add_lora_packed_nslice
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
lora_b_stacked
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
scale
:
float
,
output_slices
:
Tuple
[
int
,
...])
->
None
:
"""
Applies lora to each input. Similar to add_lora, This method is
used for layers that are composed of multiple sublayers
(slices) packed together.
"""
y_org
=
y
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
offset_left
=
0
# TODO fuse these kernels
for
slice_idx
in
range
(
len
(
output_slices
)):
self
.
add_lora
(
y
,
x
,
lora_a_stacked
[
slice_idx
],
lora_b_stacked
[
slice_idx
],
scale
,
offset_left
,
output_slices
[
slice_idx
])
offset_left
+=
output_slices
[
slice_idx
]
y
=
y
.
view_as
(
y_org
)
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wb_t_all
:
torch
.
Tensor
,
scale
,
*
,
buffer
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
"""
LogitsProcessorWithLoRA always using bgmv
"""
y_org
=
y
y
=
y
.
view
(
-
1
,
y
.
shape
[
-
1
])
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
r
=
wb_t_all
.
size
(
-
1
)
if
buffer
is
None
:
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
bgmv_shrink
(
x
,
wa_t_all
,
buffer
,
self
.
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
wb_t_all
,
y
,
self
.
sampler_indices
,
add_inputs
=
True
)
y
=
y
.
view_as
(
y_org
)
vllm/triton_utils/__init__.py
View file @
7ecee343
...
...
@@ -6,5 +6,6 @@ if HAS_TRITON:
from
vllm.triton_utils.custom_cache_manager
import
(
maybe_set_triton_cache_manager
)
from
vllm.triton_utils.libentry
import
libentry
__all__
+=
[
"maybe_set_triton_cache_manager"
]
__all__
+=
[
"maybe_set_triton_cache_manager"
,
"libentry"
]
vllm/triton_utils/libentry.py
0 → 100644
View file @
7ecee343
# Copied From https://github.com/FlagOpen/FlagGems
import
inspect
import
triton
class
LibEntry
(
triton
.
KernelInterface
):
def
__init__
(
self
,
fn
,
):
self
.
fn
=
fn
self
.
arg_names
=
fn
.
arg_names
self
.
divisibility
=
16
self
.
kernel_cache
=
dict
()
fn
=
self
.
fn
while
not
isinstance
(
fn
,
triton
.
runtime
.
JITFunction
):
fn
=
fn
.
fn
self
.
jit_function
:
triton
.
runtime
.
JITFunction
=
fn
self
.
specialize_indices
=
[
p
.
num
for
p
in
self
.
jit_function
.
params
if
not
p
.
is_constexpr
and
not
p
.
do_not_specialize
]
self
.
do_not_specialize_indices
=
[
p
.
num
for
p
in
self
.
jit_function
.
params
if
not
p
.
is_constexpr
and
p
.
do_not_specialize
]
def
key
(
self
,
spec_args
,
dns_args
,
const_args
):
spec_key
=
[(
arg
.
dtype
,
arg
.
data_ptr
()
%
self
.
divisibility
==
0
)
if
hasattr
(
arg
,
"data_ptr"
)
else
(
type
(
arg
),
arg
)
for
arg
in
spec_args
]
dns_key
=
[
arg
.
dtype
if
hasattr
(
arg
,
"data_ptr"
)
else
type
(
arg
)
if
not
isinstance
(
arg
,
int
)
else
"i32"
if
-
(
2
**
31
)
<=
arg
and
arg
<=
2
**
31
-
1
else
"u64"
if
2
**
63
<=
arg
and
arg
<=
2
**
64
-
1
else
"i64"
for
arg
in
dns_args
]
# const args passed by position
return
tuple
(
spec_key
+
dns_key
+
const_args
)
def
run
(
self
,
*
args
,
**
kwargs
):
grid
=
kwargs
[
"grid"
]
# collect all the arguments
spec_args
=
[]
# specialize arguments
dns_args
=
[]
# do not specialize arguments
const_args
=
[]
# constexpr arguments
k_args
=
[]
# kernel arguments
for
i
,
arg
in
enumerate
(
args
):
if
i
in
self
.
specialize_indices
:
k_args
.
append
(
arg
)
spec_args
.
append
(
arg
)
elif
i
in
self
.
do_not_specialize_indices
:
k_args
.
append
(
arg
)
dns_args
.
append
(
arg
)
else
:
const_args
.
append
(
arg
)
for
p
in
self
.
jit_function
.
params
[
len
(
args
):]:
if
p
.
name
in
kwargs
:
val
=
kwargs
[
p
.
name
]
elif
p
.
default
is
inspect
.
_empty
:
continue
else
:
val
=
p
.
default
if
p
.
is_constexpr
:
const_args
.
append
(
val
)
elif
p
.
do_not_specialize
:
dns_args
.
append
(
val
)
k_args
.
append
(
val
)
else
:
spec_args
.
append
(
val
)
k_args
.
append
(
val
)
entry_key
=
self
.
key
(
spec_args
,
dns_args
,
const_args
)
if
entry_key
not
in
self
.
kernel_cache
:
# compile the kernel also completes the related computations
kernel
=
self
.
fn
.
run
(
*
args
,
**
kwargs
)
fn
=
self
.
fn
# collect constexpr arguments for grid computation
constexprs
=
{}
while
not
isinstance
(
fn
,
triton
.
runtime
.
JITFunction
):
if
isinstance
(
fn
,
triton
.
runtime
.
Autotuner
):
config
=
fn
.
best_config
constexprs
[
"num_warps"
]
=
config
.
num_warps
constexprs
[
"num_stages"
]
=
config
.
num_stages
constexprs
[
"num_ctas"
]
=
config
.
num_ctas
constexprs
=
{
**
constexprs
,
**
config
.
kwargs
}
elif
isinstance
(
fn
,
triton
.
runtime
.
Heuristics
):
for
v
,
heur
in
fn
.
values
.
items
():
constexprs
[
v
]
=
heur
({
**
dict
(
zip
(
fn
.
arg_names
,
args
)),
**
kwargs
,
**
constexprs
,
})
else
:
raise
RuntimeError
(
"Invalid Runtime Function"
)
fn
=
fn
.
fn
# In vLLM, certain kernels like fused_moe_kernel get the
# best_config(as kwargs) from a configuration json file, rather
# than using Autotuner & Heuristics. Therefore, all their constexprs
# (tl.constexpr) are assigned values through the following loop.
for
p
in
self
.
jit_function
.
params
:
if
p
.
is_constexpr
and
p
.
name
not
in
constexprs
:
constexprs
[
p
.
name
]
=
p
.
default
#default=inspect._empty
self
.
kernel_cache
[
entry_key
]
=
(
kernel
,
constexprs
)
else
:
# load kernel from cache directly
kernel
,
constexprs
=
self
.
kernel_cache
[
entry_key
]
if
callable
(
grid
):
# collect all arguments to the grid fn,ie:
# 1. args,
# 2. kwargs,
# 3. all all other captured arguments in CompiledKernel from
# Autotunner & Heuristics when kwargs & captured args conflict,
# captured args have higher priority
# 4. We must filter out captured args with default value firstly
constexprs
=
{
k
:
v
for
k
,
v
in
constexprs
.
items
()
if
v
is
not
inspect
.
_empty
}
meta
=
{
**
dict
(
zip
(
self
.
arg_names
,
args
)),
**
kwargs
,
**
constexprs
,
}
grid
=
grid
(
meta
)
if
isinstance
(
grid
,
tuple
):
grid
=
grid
+
(
1
,
1
)
elif
isinstance
(
grid
,
list
):
grid
=
grid
+
[
1
,
1
]
kernel
[
grid
[
0
:
3
]](
*
k_args
)
# maintaining the same return type as the JITFunction.run
return
kernel
def
libentry
():
"""
Decorator for triton library entries.
Motivation:
The runtime overhead of Triton kernels is the reason for the lower
performance of small kernels, particularly evident with smaller models.
Using this decorator can reduce Triton runtime overhead.
How:
The `run` function of JITFunction needs to accomplish:
- Parameter binding using inspect
- KernelArg type wrapping
- Cache key calculation
When dealing with small size, these steps can become bottlenecks in
Triton runtime. Libentry simplifies these steps to reduce runtime
overhead, thereby improving the runtime expenses of small kernels.
NOTE:
When Triton is upgraded to version 3.0.0, libentry can be removed,
see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
"""
def
decorator
(
fn
):
return
LibEntry
(
fn
)
return
decorator
vllm/worker/model_runner.py
View file @
7ecee343
...
...
@@ -578,9 +578,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for
inter_data
in
self
.
inter_data_list
])
lora_mapping
=
LoRAMapping
(
lora_index_mapping
,
lora_prompt_mapping
,
)
**
dict
(
index_mapping
=
lora_index_mapping
,
prompt_mapping
=
lora_prompt_mapping
,
is_prefill
=
not
self
.
decode_only
)
)
# Prompt adapter data.
prompt_adapter_requests
:
Set
[
PromptAdapterRequest
]
=
set
()
...
...
@@ -1152,9 +1152,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if
self
.
lora_config
:
lora_mapping
=
LoRAMapping
(
[
0
]
*
batch_size
,
[
0
]
*
batch_size
,
)
**
dict
(
index_mapping
=
[
0
]
*
batch_size
,
prompt_mapping
=
[
0
]
*
batch_size
,
is_prefill
=
False
)
)
self
.
set_active_loras
(
set
(),
lora_mapping
)
if
self
.
prompt_adapter_config
:
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment