Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2385b60d
Unverified
Commit
2385b60d
authored
Nov 22, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 21, 2024
Browse files
[Kernel] Register punica ops directly (#10522)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
da7e702c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
157 additions
and
24 deletions
+157
-24
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+17
-6
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+20
-3
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+22
-3
vllm/lora/ops/bgmv_shrink.py
vllm/lora/ops/bgmv_shrink.py
+20
-3
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+26
-3
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+27
-3
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+25
-3
No files found.
tests/lora/test_punica_variation.py
View file @
2385b60d
...
...
@@ -6,12 +6,13 @@ maximum ranks.
import
pytest
import
torch
from
vllm.lora.ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
# Enable custom op register
import
vllm.lora.ops.bgmv_expand
import
vllm.lora.ops.bgmv_expand_slice
import
vllm.lora.ops.bgmv_shrink
import
vllm.lora.ops.sgmv_expand
import
vllm.lora.ops.sgmv_expand_slice
import
vllm.lora.ops.sgmv_shrink
# noqa: F401
from
vllm.platforms
import
current_platform
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
...
...
@@ -37,6 +38,16 @@ def assert_close(a, b):
torch
.
testing
.
assert_close
(
a
,
b
,
rtol
=
rtol
,
atol
=
atol
)
# Unlike test_punica_sizes.py, we directly utilize custom op for
# testing, which verifies the correct registration of these ops.
bgmv_expand
=
torch
.
ops
.
vllm
.
bgmv_expand
bgmv_expand_slice
=
torch
.
ops
.
vllm
.
bgmv_expand_slice
bgmv_shrink
=
torch
.
ops
.
vllm
.
bgmv_shrink
sgmv_expand
=
torch
.
ops
.
vllm
.
sgmv_expand
sgmv_expand_slice
=
torch
.
ops
.
vllm
.
sgmv_expand_slice
sgmv_shrink
=
torch
.
ops
.
vllm
.
sgmv_shrink
@
pytest
.
mark
.
parametrize
(
"batches"
,
BATCHES
)
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
NUM_LORA
)
@
pytest
.
mark
.
parametrize
(
"rank"
,
MAX_RANKS
)
...
...
vllm/lora/ops/bgmv_expand.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
from
.utils
import
get_lora_op_configs
...
...
@@ -162,9 +164,24 @@ def _bgmv_expand(
return
def
bgmv_expand_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
add_inputs
:
bool
=
True
,
)
->
None
:
return
try
:
bgmv_expand
=
torch
.
library
.
custom_op
(
"lora::bgmv_expand"
,
_bgmv_expand
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"bgmv_expand"
,
op_func
=
_bgmv_expand
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
bgmv_expand_fake
,
)
bgmv_expand
=
torch
.
ops
.
vllm
.
bgmv_expand
except
AttributeError
:
bgmv_expand
=
_bgmv_expand
vllm/lora/ops/bgmv_expand_slice.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
from
.utils
import
get_lora_op_configs
...
...
@@ -179,9 +181,26 @@ def _bgmv_expand_slice(
return
def
bgmv_expand_slice_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
True
,
)
->
None
:
return
try
:
bgmv_expand_slice
=
torch
.
library
.
custom_op
(
"lora::bgmv_expand_slice"
,
_bgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"bgmv_expand_slice"
,
op_func
=
_bgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
bgmv_expand_slice_fake
,
)
bgmv_expand_slice
=
torch
.
ops
.
vllm
.
bgmv_expand_slice
except
AttributeError
:
bgmv_expand_slice
=
_bgmv_expand_slice
vllm/lora/ops/bgmv_shrink.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
from
.utils
import
get_lora_op_configs
...
...
@@ -142,9 +144,24 @@ def _bgmv_shrink(
return
def
bgmv_shrink_fake
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
scaling
:
float
=
1.0
,
)
->
None
:
return
try
:
bgmv_shrink
=
torch
.
library
.
custom_op
(
"lora::bgmv_shrink"
,
_bgmv_shrink
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"bgmv_shrink"
,
op_func
=
_bgmv_shrink
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
bgmv_shrink_fake
,
)
bgmv_shrink
=
torch
.
ops
.
vllm
.
bgmv_shrink
except
AttributeError
:
bgmv_shrink
=
_bgmv_shrink
vllm/lora/ops/sgmv_expand.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
@
triton
.
jit
def
_sgmv_expand_kernel
(
...
...
@@ -196,9 +198,30 @@ def _sgmv_expand(
return
def
sgmv_expand_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
b_seq_start_loc
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
add_inputs
:
bool
=
False
,
)
->
None
:
return
try
:
sgmv_expand
=
torch
.
library
.
custom_op
(
"lora::sgmv_expand"
,
_sgmv_expand
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"sgmv_expand"
,
op_func
=
_sgmv_expand
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
sgmv_expand_fake
,
)
sgmv_expand
=
torch
.
ops
.
vllm
.
sgmv_expand
except
AttributeError
:
sgmv_expand
=
_sgmv_expand
vllm/lora/ops/sgmv_expand_slice.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
@
triton
.
jit
def
_sgmv_expand_slice_kernel
(
...
...
@@ -209,9 +211,31 @@ def _sgmv_expand_slice(
return
def
sgmv_expand_slice_fake
(
inputs
:
torch
.
Tensor
,
lora_b_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
b_seq_start_loc
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
slice_offset
:
int
,
slice_size
:
int
,
add_inputs
:
bool
=
False
,
)
->
None
:
return
try
:
sgmv_expand_slice
=
torch
.
library
.
custom_op
(
"lora::sgmv_expand_slice"
,
_sgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"sgmv_expand_slice"
,
op_func
=
_sgmv_expand_slice
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
sgmv_expand_slice_fake
,
)
sgmv_expand_slice
=
torch
.
ops
.
vllm
.
sgmv_expand_slice
except
AttributeError
:
sgmv_expand_slice
=
_sgmv_expand_slice
vllm/lora/ops/sgmv_shrink.py
View file @
2385b60d
...
...
@@ -9,6 +9,8 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm.utils
import
direct_register_custom_op
@
triton
.
jit
def
_sgmv_shrink_kernel
(
...
...
@@ -190,9 +192,29 @@ def _sgmv_shrink(
return
def
sgmv_shrink_fake
(
inputs
:
torch
.
Tensor
,
lora_a_weights
:
torch
.
Tensor
,
output_tensor
:
torch
.
Tensor
,
b_seq_start_loc
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
lora_indices_tensor
:
torch
.
Tensor
,
batches
:
int
,
max_seq_length
:
int
,
token_nums
:
int
,
scaling
:
float
,
)
->
None
:
return
try
:
sgmv_shrink
=
torch
.
library
.
custom_op
(
"lora::sgmv_shrink"
,
_sgmv_shrink
,
mutates_args
=
[
"output_tensor"
])
direct_register_custom_op
(
op_name
=
"sgmv_shrink"
,
op_func
=
_sgmv_shrink
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
sgmv_shrink_fake
,
)
sgmv_shrink
=
torch
.
ops
.
vllm
.
sgmv_shrink
except
AttributeError
:
sgmv_shrink
=
_sgmv_shrink
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment