Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f698563
Unverified
Commit
9f698563
authored
Aug 16, 2024
by
bnellnm
Committed by
GitHub
Aug 16, 2024
Browse files
[Kernel] register punica functions as torch ops (#7591)
parent
d4f0f17b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
46 additions
and
42 deletions
+46
-42
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+8
-12
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+8
-12
vllm/lora/ops/bgmv_shrink.py
vllm/lora/ops/bgmv_shrink.py
+9
-12
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+7
-2
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+7
-2
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+7
-2
No files found.
vllm/lora/ops/bgmv_expand.py
View file @
9f698563
...
...
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Dict
,
Optional
import
torch
import
triton
import
triton.language
as
tl
...
...
@@ -86,14 +84,13 @@ def _bgmv_expand_kernel(
@
torch
.
inference_mode
()
def
bgmv_expand
(
def
_
bgmv_expand
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
add_inputs
:
bool
=
True
,
override_config
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
):
)
->
None
:
"""
Args:
inputs (torch.Tensor): input tensor
...
...
@@ -105,10 +102,7 @@ def bgmv_expand(
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
lora_b_weights
.
dtype
in
[
torch
.
float16
,
...
...
@@ -138,10 +132,7 @@ def bgmv_expand(
]:
CAST_TYPE
=
True
batches
=
lora_indices_tensor
.
size
(
0
)
if
override_config
:
config
=
override_config
else
:
config
=
get_lora_op_configs
(
"expand"
,
batches
,
N
)
config
=
get_lora_op_configs
(
"expand"
,
batches
,
N
)
grid
=
lambda
META
:
(
META
[
"SPLIT_N"
],
batches
,
...
...
@@ -167,3 +158,8 @@ def bgmv_expand(
**
config
,
)
return
bgmv_expand
=
torch
.
library
.
custom_op
(
"lora::bgmv_expand"
,
_bgmv_expand
,
mutates_args
=
[
"output_tensor"
])
vllm/lora/ops/bgmv_expand_slice.py
View file @
9f698563
...
...
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Dict
,
Optional
import
torch
import
triton
import
triton.language
as
tl
...
...
@@ -89,7 +87,7 @@ def _bgmv_expand_slice_kernel(
@
torch
.
inference_mode
()
def
bgmv_expand_slice
(
def
_
bgmv_expand_slice
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
...
...
@@ -97,8 +95,7 @@ def bgmv_expand_slice(
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
True
,
override_config
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
):
)
->
None
:
"""
Args:
inputs (torch.Tensor): input tensor
...
...
@@ -111,10 +108,7 @@ def bgmv_expand_slice(
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
assert
lora_b_weights
.
dtype
in
[
torch
.
float16
,
...
...
@@ -149,10 +143,7 @@ def bgmv_expand_slice(
batches
=
lora_indices_tensor
.
size
(
0
)
if
override_config
:
config
=
override_config
else
:
config
=
get_lora_op_configs
(
"expand"
,
batches
,
N
)
config
=
get_lora_op_configs
(
"expand"
,
batches
,
N
)
grid
=
lambda
META
:
(
META
[
"SPLIT_N"
],
...
...
@@ -180,3 +171,8 @@ def bgmv_expand_slice(
**
config
,
)
return
bgmv_expand_slice
=
torch
.
library
.
custom_op
(
"lora::bgmv_expand_slice"
,
_bgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
])
vllm/lora/ops/bgmv_shrink.py
View file @
9f698563
...
...
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from
typing
import
Dict
,
Optional
import
torch
import
triton
import
triton.language
as
tl
...
...
@@ -78,14 +76,13 @@ def _bgmv_shrink_kernel(
@
torch
.
inference_mode
()
def
bgmv_shrink
(
def
_
bgmv_shrink
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
scaling
:
float
=
1.0
,
override_config
:
Optional
[
Dict
[
str
,
int
]]
=
None
,
):
)
->
None
:
"""
Args:
inputs (torch.Tensor): input tensor
...
...
@@ -96,8 +93,6 @@ def bgmv_shrink(
applied.
batches (int): batch size
scaling (float): Scaling factor.
override_config (Optional[Dict[str, int]], optional): Defaults to None.
Triton grid config
"""
assert
inputs
.
dtype
==
lora_a_weights
.
dtype
assert
inputs
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
...
...
@@ -119,11 +114,8 @@ def bgmv_shrink(
batches
=
lora_indices_tensor
.
size
(
0
)
N
,
K
=
lora_a_weights
.
shape
[
-
2
:]
# K=hidden_size,N=rank
BLOCK_N
=
triton
.
next_power_of_2
(
N
)
if
override_config
:
config
=
override_config
else
:
# First try to load optimal config from the file
config
=
get_lora_op_configs
(
"bgmv_shrink"
,
batches
,
K
)
# First try to load optimal config from the file
config
=
get_lora_op_configs
(
"bgmv_shrink"
,
batches
,
K
)
grid
=
lambda
META
:
(
META
[
"SPLIT_K"
],
...
...
@@ -148,3 +140,8 @@ def bgmv_shrink(
**
config
,
)
return
bgmv_shrink
=
torch
.
library
.
custom_op
(
"lora::bgmv_shrink"
,
_bgmv_shrink
,
mutates_args
=
[
"output_tensor"
])
vllm/lora/ops/sgmv_expand.py
View file @
9f698563
...
...
@@ -97,7 +97,7 @@ def _sgmv_expand_kernel(
@
torch
.
inference_mode
()
def
sgmv_expand
(
def
_
sgmv_expand
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
...
...
@@ -107,7 +107,7 @@ def sgmv_expand(
batches
:
int
,
max_seq_length
:
int
,
add_inputs
:
bool
=
False
,
):
)
->
None
:
"""
Args:
inputs (torch.Tensor): input tensor
...
...
@@ -190,3 +190,8 @@ def sgmv_expand(
CAST_TYPE
,
)
return
sgmv_expand
=
torch
.
library
.
custom_op
(
"lora::sgmv_expand"
,
_sgmv_expand
,
mutates_args
=
[
"output_tensor"
])
vllm/lora/ops/sgmv_expand_slice.py
View file @
9f698563
...
...
@@ -103,7 +103,7 @@ def _sgmv_expand_slice_kernel(
@
torch
.
inference_mode
()
def
sgmv_expand_slice
(
def
_
sgmv_expand_slice
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
...
...
@@ -115,7 +115,7 @@ def sgmv_expand_slice(
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
):
)
->
None
:
"""_summary_
Args:
...
...
@@ -203,3 +203,8 @@ def sgmv_expand_slice(
CAST_TYPE
,
)
return
sgmv_expand_slice
=
torch
.
library
.
custom_op
(
"lora::sgmv_expand_slice"
,
_sgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
])
vllm/lora/ops/sgmv_shrink.py
View file @
9f698563
...
...
@@ -101,7 +101,7 @@ def _sgmv_shrink_kernel(
@
torch
.
inference_mode
()
def
sgmv_shrink
(
def
_
sgmv_shrink
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
...
...
@@ -111,7 +111,7 @@ def sgmv_shrink(
batches
:
int
,
max_seq_length
:
int
,
scaling
:
float
,
):
)
->
None
:
"""
Args:
...
...
@@ -187,3 +187,8 @@ def sgmv_shrink(
SPLIT_K
,
)
return
sgmv_shrink
=
torch
.
library
.
custom_op
(
"lora::sgmv_shrink"
,
_sgmv_shrink
,
mutates_args
=
[
"output_tensor"
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment