Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
36e4acd0
Unverified
Commit
36e4acd0
authored
Nov 11, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 11, 2024
Browse files
[LoRA][Kernel] Remove the unused libentry module (#10214)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
58170d65
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
276 deletions
+49
-276
tests/lora/test_punica_sizes.py
tests/lora/test_punica_sizes.py
+24
-49
tests/lora/test_punica_variation.py
tests/lora/test_punica_variation.py
+24
-49
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+0
-3
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+0
-3
vllm/lora/ops/sgmv_shrink.py
vllm/lora/ops/sgmv_shrink.py
+0
-3
vllm/triton_utils/__init__.py
vllm/triton_utils/__init__.py
+1
-2
vllm/triton_utils/libentry.py
vllm/triton_utils/libentry.py
+0
-167
No files found.
tests/lora/test_punica_sizes.py
View file @
36e4acd0
...
@@ -4,8 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
...
@@ -4,8 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
is set to [1, 2, 4, 8, 16, 32, 64].
"""
"""
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
...
@@ -16,7 +14,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
...
@@ -16,7 +14,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -235,9 +232,6 @@ def test_punica_bgmv(
...
@@ -235,9 +232,6 @@ def test_punica_bgmv(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -262,33 +256,21 @@ def test_punica_bgmv(
...
@@ -262,33 +256,21 @@ def test_punica_bgmv(
device
,
device
,
)
)
if
op_type
==
"shrink"
:
if
op_type
==
"shrink"
:
# The current _bgmv_shrink_kernel does not require the libentry
bgmv_shrink
(
# decoration. The purpose of adding this patch is to test the
inputs_tensor
,
# correctness of libentry.
lora_weights
,
with
patch
(
our_out_tensor
,
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel"
,
indices
,
LibEntry
(
_bgmv_shrink_kernel
),
scaling
,
):
)
bgmv_shrink
(
inputs_tensor
,
lora_weights
,
our_out_tensor
,
indices
,
scaling
,
)
else
:
else
:
# ditto
bgmv_expand
(
with
patch
(
inputs_tensor
,
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel"
,
lora_weights
,
LibEntry
(
_bgmv_expand_kernel
),
our_out_tensor
,
):
indices
,
bgmv_expand
(
add_inputs
=
True
,
inputs_tensor
,
)
lora_weights
,
our_out_tensor
,
indices
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
ref_out_tensor
,
ref_out_tensor
,
inputs_tensor
,
inputs_tensor
,
...
@@ -324,7 +306,6 @@ def test_punica_expand_nslices(
...
@@ -324,7 +306,6 @@ def test_punica_expand_nslices(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -374,22 +355,16 @@ def test_punica_expand_nslices(
...
@@ -374,22 +355,16 @@ def test_punica_expand_nslices(
add_inputs
=
True
,
add_inputs
=
True
,
)
)
else
:
else
:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
bgmv_expand_slice
(
# the correctness of libentry.
inputs_tensor
,
with
patch
(
lora_weights
,
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel"
,
our_outputs
,
LibEntry
(
_bgmv_expand_slice_kernel
),
indices
,
):
slice_offset
,
bgmv_expand_slice
(
slice_size
=
hidden_size
,
inputs_tensor
,
add_inputs
=
True
,
lora_weights
,
)
our_outputs
,
indices
,
slice_offset
,
slice_size
=
hidden_size
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
ref_outputs
[:,
slice_offset
:
slice_offset
+
hidden_size
],
ref_outputs
[:,
slice_offset
:
slice_offset
+
hidden_size
],
inputs_tensor
,
inputs_tensor
,
...
...
tests/lora/test_punica_variation.py
View file @
36e4acd0
...
@@ -3,8 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
...
@@ -3,8 +3,6 @@ This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
maximum ranks.
"""
"""
from
unittest.mock
import
patch
import
pytest
import
pytest
import
torch
import
torch
...
@@ -15,7 +13,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
...
@@ -15,7 +13,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_expand_slice
import
sgmv_expand_slice
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.sgmv_shrink
import
sgmv_shrink
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils.libentry
import
LibEntry
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
from
.utils
import
(
generate_data
,
generate_data_for_expand_nslices
,
ref_torch_groupgemm
)
ref_torch_groupgemm
)
...
@@ -150,8 +147,6 @@ def test_punica_bgmv(
...
@@ -150,8 +147,6 @@ def test_punica_bgmv(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
from
vllm.lora.ops.bgmv_expand
import
_bgmv_expand_kernel
from
vllm.lora.ops.bgmv_shrink
import
_bgmv_shrink_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -177,33 +172,22 @@ def test_punica_bgmv(
...
@@ -177,33 +172,22 @@ def test_punica_bgmv(
device
,
device
,
)
)
if
op_type
==
"shrink"
:
if
op_type
==
"shrink"
:
# The current _bgmv_shrink_kernel does not require the libentry
bgmv_shrink
(
# decoration. The purpose of adding this patch is to test the
inputs_tensor
,
# correctness of libentry.
lora_weights
,
with
patch
(
our_out_tensor
,
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel"
,
indices
,
LibEntry
(
_bgmv_shrink_kernel
),
scaling
,
):
)
bgmv_shrink
(
inputs_tensor
,
lora_weights
,
our_out_tensor
,
indices
,
scaling
,
)
else
:
else
:
# ditto
with
patch
(
bgmv_expand
(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel"
,
inputs_tensor
,
LibEntry
(
_bgmv_expand_kernel
),
lora_weights
,
):
our_out_tensor
,
bgmv_expand
(
indices
,
inputs_tensor
,
add_inputs
=
True
,
lora_weights
,
)
our_out_tensor
,
indices
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
ref_out_tensor
,
ref_out_tensor
,
inputs_tensor
,
inputs_tensor
,
...
@@ -239,8 +223,6 @@ def test_punica_expand_nslices(
...
@@ -239,8 +223,6 @@ def test_punica_expand_nslices(
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
):
):
from
vllm.lora.ops.bgmv_expand_slice
import
_bgmv_expand_slice_kernel
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
...
@@ -289,22 +271,15 @@ def test_punica_expand_nslices(
...
@@ -289,22 +271,15 @@ def test_punica_expand_nslices(
add_inputs
=
True
,
add_inputs
=
True
,
)
)
else
:
else
:
# The current _bgmv_expand_slice_kernel does not require the
bgmv_expand_slice
(
# libentry decoration. The purpose of adding this patch is to test
inputs_tensor
,
# the correctness of libentry.
lora_weights
,
with
patch
(
our_outputs
,
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel"
,
indices
,
LibEntry
(
_bgmv_expand_slice_kernel
),
slice_offset
,
):
slice_size
=
hidden_size
,
bgmv_expand_slice
(
add_inputs
=
True
,
inputs_tensor
,
)
lora_weights
,
our_outputs
,
indices
,
slice_offset
,
slice_size
=
hidden_size
,
add_inputs
=
True
,
)
ref_torch_groupgemm
(
ref_torch_groupgemm
(
ref_outputs
[:,
slice_offset
:
slice_offset
+
hidden_size
],
ref_outputs
[:,
slice_offset
:
slice_offset
+
hidden_size
],
inputs_tensor
,
inputs_tensor
,
...
...
vllm/lora/ops/sgmv_expand.py
View file @
36e4acd0
...
@@ -9,10 +9,7 @@ import torch
...
@@ -9,10 +9,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.triton_utils
import
libentry
@
libentry
()
@
triton
.
jit
@
triton
.
jit
def
_sgmv_expand_kernel
(
def
_sgmv_expand_kernel
(
input_ptr
,
input_ptr
,
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
36e4acd0
...
@@ -9,10 +9,7 @@ import torch
...
@@ -9,10 +9,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.triton_utils
import
libentry
@
libentry
()
@
triton
.
jit
@
triton
.
jit
def
_sgmv_expand_slice_kernel
(
def
_sgmv_expand_slice_kernel
(
input_ptr
,
input_ptr
,
...
...
vllm/lora/ops/sgmv_shrink.py
View file @
36e4acd0
...
@@ -9,10 +9,7 @@ import torch
...
@@ -9,10 +9,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
from
vllm.triton_utils
import
libentry
@
libentry
()
@
triton
.
jit
@
triton
.
jit
def
_sgmv_shrink_kernel
(
def
_sgmv_shrink_kernel
(
input_ptr
,
input_ptr
,
...
...
vllm/triton_utils/__init__.py
View file @
36e4acd0
...
@@ -6,6 +6,5 @@ if HAS_TRITON:
...
@@ -6,6 +6,5 @@ if HAS_TRITON:
from
vllm.triton_utils.custom_cache_manager
import
(
from
vllm.triton_utils.custom_cache_manager
import
(
maybe_set_triton_cache_manager
)
maybe_set_triton_cache_manager
)
from
vllm.triton_utils.libentry
import
libentry
__all__
+=
[
"maybe_set_triton_cache_manager"
,
"libentry"
]
__all__
+=
[
"maybe_set_triton_cache_manager"
]
vllm/triton_utils/libentry.py
deleted
100644 → 0
View file @
58170d65
# Copied From https://github.com/FlagOpen/FlagGems
import
inspect
import
triton
class LibEntry(triton.KernelInterface):
    """Launch cache placed in front of a (possibly decorated) Triton kernel.

    The first launch for a given argument signature is delegated to the
    wrapped object's own ``run``; the kernel it returns is stored under a
    cheap key.  Later launches with the same key index the stored kernel
    directly and never call the wrapped ``run`` again.

    Copied from https://github.com/FlagOpen/FlagGems.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap ``fn`` and record how each of its parameters is treated.

        Args:
            fn: A ``triton.runtime.JITFunction`` or an object that wraps one
                through a chain of ``.fn`` attributes.  It must expose
                ``arg_names`` and ``run``.
        """
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer alignment (in bytes) that is folded into the cache key.
        self.divisibility = 16
        # entry key -> (kernel, constexprs usable by a callable grid)
        self.kernel_cache = {}

        # Peel decorator layers until the underlying JITFunction is reached.
        innermost = self.fn
        while not isinstance(innermost, triton.runtime.JITFunction):
            innermost = innermost.fn
        self.jit_function: triton.runtime.JITFunction = innermost

        self.specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        # ``run`` classifies every positional argument on every launch, so
        # keep O(1) membership views next to the public lists.
        self._specialize_index_set = frozenset(self.specialize_indices)
        self._dns_index_set = frozenset(self.do_not_specialize_indices)

    @staticmethod
    def _dns_type_key(arg):
        """Return the key component for one do-not-specialize argument.

        Objects with a ``data_ptr`` attribute contribute their ``dtype``;
        integers contribute a width label ("i32", "u64" or "i64"); anything
        else contributes its Python type.
        """
        if hasattr(arg, "data_ptr"):
            return arg.dtype
        if not isinstance(arg, int):
            return type(arg)
        if -(2**31) <= arg <= 2**31 - 1:
            return "i32"
        if 2**63 <= arg <= 2**64 - 1:
            return "u64"
        return "i64"

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key for one launch.

        Args:
            spec_args: Values bound to specialized parameters.
            dns_args: Values bound to do-not-specialize parameters.
            const_args: Values bound to constexpr parameters.

        Returns:
            A tuple combining the three groups, in that order.
        """
        spec_key = [
            (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            if hasattr(arg, "data_ptr") else (type(arg), arg)
            for arg in spec_args
        ]
        dns_key = [self._dns_type_key(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, using the cached entry when one exists.

        Args:
            *args: Positional kernel arguments.
            **kwargs: Keyword kernel arguments; ``grid`` is required and may
                be a tuple, a list, or a callable taking a meta dict.

        Returns:
            The kernel object, the same return type as ``JITFunction.run``.

        Raises:
            KeyError: If ``grid`` is not supplied as a keyword argument.
            RuntimeError: If a layer between the wrapped object and the
                JITFunction is neither an Autotuner nor a Heuristics.
        """
        grid = kwargs["grid"]

        # Sort every argument into its category.
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = []  # kernel arguments
        for i, arg in enumerate(args):
            if i in self._specialize_index_set:
                k_args.append(arg)
                spec_args.append(arg)
            elif i in self._dns_index_set:
                k_args.append(arg)
                dns_args.append(arg)
            else:
                const_args.append(arg)

        # Parameters not covered positionally come from kwargs or from
        # their declared default; those with neither are skipped.
        for p in self.jit_function.params[len(args):]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args.append(val)
            else:
                spec_args.append(val)
                k_args.append(val)

        entry_key = self.key(spec_args, dns_args, const_args)

        if entry_key not in self.kernel_cache:
            # compile the kernel also completes the related computations
            kernel = self.fn.run(*args, **kwargs)
            fn = self.fn
            # collect constexpr arguments for grid computation
            constexprs = {}
            while not isinstance(fn, triton.runtime.JITFunction):
                if isinstance(fn, triton.runtime.Autotuner):
                    config = fn.best_config
                    constexprs["num_warps"] = config.num_warps
                    constexprs["num_stages"] = config.num_stages
                    constexprs["num_ctas"] = config.num_ctas
                    constexprs = {**constexprs, **config.kwargs}
                elif isinstance(fn, triton.runtime.Heuristics):
                    for v, heur in fn.values.items():
                        constexprs[v] = heur({
                            **dict(zip(fn.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        })
                else:
                    raise RuntimeError("Invalid Runtime Function")
                fn = fn.fn
            # In vLLM, certain kernels like fused_moe_kernel get the
            # best_config(as kwargs) from a configuration json file, rather
            # than using Autotuner & Heuristics. Therefore, all their
            # constexprs (tl.constexpr) are assigned values through the
            # following loop.
            for p in self.jit_function.params:
                if p.is_constexpr and p.name not in constexprs:
                    constexprs[p.name] = p.default  # may be inspect._empty
            # Entries still set to inspect._empty carry no value and must
            # not reach the grid function.  Dropping them once here spares
            # every cache hit from rebuilding the filtered dict.
            resolved = {
                name: value
                for name, value in constexprs.items()
                if value is not inspect._empty
            }
            self.kernel_cache[entry_key] = (kernel, resolved)
        else:
            # load kernel from cache directly
            kernel, constexprs = self.kernel_cache[entry_key]

            if callable(grid):
                # Collect all arguments to the grid fn, i.e.:
                # 1. args,
                # 2. kwargs,
                # 3. all other arguments captured from Autotuner &
                #    Heuristics; when kwargs and captured args conflict,
                #    captured args have higher priority.
                meta = {
                    **dict(zip(self.arg_names, args)),
                    **kwargs,
                    **constexprs,
                }
                grid = grid(meta)
            # Pad so that the slice below always yields three dimensions.
            if isinstance(grid, tuple):
                grid = grid + (1, 1)
            elif isinstance(grid, list):
                grid = grid + [1, 1]
            kernel[grid[0:3]](*k_args)
        # maintaining the same return type as the JITFunction.run
        return kernel
def libentry():
    """Return a decorator that wraps a Triton kernel in ``LibEntry``.

    Motivation:
        Triton's runtime overhead is what holds back the performance of
        small kernels, which shows most clearly with smaller models.
        Wrapping a kernel with this decorator can reduce that overhead.

    How:
        ``JITFunction.run`` has to perform, on every launch:
            - parameter binding using inspect
            - KernelArg type wrapping
            - cache key calculation
        For small sizes these steps can become the bottleneck of the Triton
        runtime. Libentry simplifies them, lowering the runtime cost of
        small kernels.

    NOTE:
        When Triton is upgraded to version 3.0.0, libentry can be removed,
        see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
    """

    def wrap_kernel(fn):
        # Hand the decorated object to the caching launcher.
        wrapped = LibEntry(fn)
        return wrapped

    return wrap_kernel
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment