Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
300da091
Unverified
Commit
300da091
authored
Sep 25, 2024
by
bnellnm
Committed by
GitHub
Sep 25, 2024
Browse files
[Kernel] Fullgraph and opcheck tests (#8479)
parent
1c046447
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
157 additions
and
38 deletions
+157
-38
tests/kernels/test_rotary_embedding.py
tests/kernels/test_rotary_embedding.py
+62
-0
tests/kernels/test_utils.py
tests/kernels/test_utils.py
+24
-0
tests/kernels/utils.py
tests/kernels/utils.py
+37
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+31
-30
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+2
-2
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+1
-0
No files found.
tests/kernels/test_rotary_embedding.py
0 → 100644
View file @
300da091
"""
Tests for miscellaneous utilities
"""
from
typing
import
Optional
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
def
rotary_embedding_opcheck
(
rot
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if
offsets
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
batched_rotary_embedding
,
(
positions
,
query
,
key
,
rot
.
head_size
,
cos_sin_cache
,
rot
.
is_neox_style
,
rot
.
rotary_dim
,
offsets
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rotary_embedding
,
(
positions
,
query
,
key
,
rot
.
head_size
,
cos_sin_cache
,
rot
.
is_neox_style
))
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"max_position"
,
[
11
,
4096
,
32768
])
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
):
batch_size
=
1
base
=
0
num_heads
=
7
rot
=
RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
torch
.
float32
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
torch
.
float32
,
device
=
device
)
key
=
torch
.
randn_like
(
query
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
device
=
device
,
dtype
=
torch
.
long
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
tests/kernels/test_utils.py
0 → 100644
View file @
300da091
"""
Tests for miscellaneous utilities
"""
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm.platforms
import
current_platform
def
test_convert_fp8_opcheck
():
data
=
torch
.
randn
((
256
,
256
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
result
=
torch
.
empty_like
(
data
,
dtype
=
torch
.
float8_e4m3fn
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
convert_fp8
,
(
result
,
data
,
1.0
,
"fp8"
))
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Only supported for CUDA"
)
def
test_cuda_utils_opcheck
():
opcheck
(
torch
.
ops
.
_C_cuda_utils
.
get_device_attribute
,
(
0
,
0
))
opcheck
(
torch
.
ops
.
_C_cuda_utils
.
get_max_shared_memory_per_block_device_attribute
,
(
0
,
))
tests/kernels/utils.py
View file @
300da091
...
@@ -2,12 +2,14 @@
...
@@ -2,12 +2,14 @@
import
itertools
import
itertools
import
random
import
random
import
unittest
from
numbers
import
Number
from
numbers
import
Number
from
typing
import
(
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
from
typing
import
(
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
Union
)
Union
)
import
pytest
import
pytest
import
torch
import
torch
from
torch._prims_common
import
TensorLikeType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionBackend
,
AttentionMetadata
,
AttentionType
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_XFORMERS_ATTN_VAL
,
from
vllm.utils
import
(
STR_BACKEND_ENV_VAR
,
STR_XFORMERS_ATTN_VAL
,
...
@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
...
@@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
output_under_test
.
view_as
(
ideal_output
))
output_under_test
.
view_as
(
ideal_output
))
# Copied/modified from torch._refs.__init__.py
def
fp8_allclose
(
a
:
TensorLikeType
,
b
:
TensorLikeType
,
rtol
:
float
=
1e-05
,
atol
:
float
=
1e-08
,
equal_nan
:
bool
=
False
,
)
->
bool
:
"""
Reference implementation of torch.allclose
"""
torch
.
_refs
.
_check_close_args
(
name
=
"torch.allclose"
,
a
=
a
,
b
=
b
,
rtol
=
rtol
,
atol
=
atol
)
return
bool
(
torch
.
all
(
torch
.
isclose
(
a
.
double
(),
b
.
double
(),
rtol
=
rtol
,
atol
=
atol
,
equal_nan
=
equal_nan
)).
item
())
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
torch
.
_library
.
custom_ops
.
CustomOpDef
],
torch
.
_library
.
custom_ops
.
CustomOpDef
],
args
:
Tuple
[
Any
,
...],
args
:
Tuple
[
Any
,
...],
...
@@ -954,6 +984,7 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
...
@@ -954,6 +984,7 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
test_utils
:
Union
[
str
,
Sequence
[
str
]]
=
ALL_OPCHECK_TEST_UTILS
,
test_utils
:
Union
[
str
,
Sequence
[
str
]]
=
ALL_OPCHECK_TEST_UTILS
,
raise_exception
:
bool
=
True
,
raise_exception
:
bool
=
True
,
cond
:
bool
=
True
)
->
Dict
[
str
,
str
]:
cond
:
bool
=
True
)
->
Dict
[
str
,
str
]:
with
unittest
.
mock
.
patch
(
'torch.allclose'
,
new
=
fp8_allclose
):
return
torch
.
library
.
opcheck
(
return
torch
.
library
.
opcheck
(
op
,
op
,
args
,
args
,
...
...
vllm/_custom_ops.py
View file @
300da091
...
@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
...
@@ -20,8 +20,10 @@ if not current_platform.is_tpu():
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
import
vllm._rocm_C
# noqa: F401
import
vllm._rocm_C
# noqa: F401
supports_moe_ops
=
False
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
# noqa: F401
import
vllm._moe_C
# noqa: F401
supports_moe_ops
=
True
def
hint_on_error
(
fn
):
def
hint_on_error
(
fn
):
...
@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx
,
use_exllama
,
bit
)
b_g_idx
,
use_exllama
,
bit
)
# TODO: has to be a better way to do this
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_gemm"
):
try
:
torch
.
ops
.
_C
.
gptq_gemm
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
@@ -265,8 +265,6 @@ try:
...
@@ -265,8 +265,6 @@ try:
return
torch
.
empty
((
a
.
size
(
0
),
b_q_weight
.
size
(
1
)),
return
torch
.
empty
((
a
.
size
(
0
),
b_q_weight
.
size
(
1
)),
dtype
=
a
.
dtype
,
dtype
=
a
.
dtype
,
device
=
a
.
device
)
device
=
a
.
device
)
except
Exception
:
pass
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
def
gptq_shuffle
(
q_weight
:
torch
.
Tensor
,
q_perm
:
torch
.
Tensor
,
...
@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
size_n
,
size_k
)
# TODO: has to be a better way to do this
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_marlin_24_gemm"
):
try
:
torch
.
ops
.
_C
.
gptq_marlin_24_gemm
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_24_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_24_gemm"
)
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
...
@@ -420,8 +416,8 @@ try:
...
@@ -420,8 +416,8 @@ try:
@
torch
.
library
.
register_fake
(
"_C::machete_gemm"
)
@
torch
.
library
.
register_fake
(
"_C::machete_gemm"
)
def
machete_gemm_fake
(
def
machete_gemm_fake
(
a
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b_q
:
torch
.
# Should be the tensor returned by machete_prepack_B
Tensor
,
# Should be the tensor returned by machete_prepack_B
b_q
:
torch
.
Tensor
,
b_type
:
ScalarType
,
b_type
:
ScalarType
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_scales
:
Optional
[
torch
.
Tensor
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
b_zeros
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -451,10 +447,10 @@ try:
...
@@ -451,10 +447,10 @@ try:
return
torch
.
empty_like
(
x
)
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_update"
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_update"
)
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
def
causal_conv1d_update_fake
(
weight
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
silu_activation
:
bool
)
->
torch
.
Tensor
:
conv_state_indices
:
Optional
[
torch
.
Tensor
]
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::selective_scan_fwd"
)
@
torch
.
library
.
register_fake
(
"_C::selective_scan_fwd"
)
...
@@ -465,20 +461,11 @@ try:
...
@@ -465,20 +461,11 @@ try:
delta_softplus
:
bool
,
index_
:
Optional
[
torch
.
Tensor
],
delta_softplus
:
bool
,
index_
:
Optional
[
torch
.
Tensor
],
x
:
Optional
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
x
:
Optional
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
a
=
torch
.
empty_like
(
u
)
a
=
torch
.
empty_like
(
u
)
if
x
is
not
None
:
b
=
x
else
:
b
=
torch
.
empty
((
u
.
size
(
0
),
u
.
size
(
1
),
A
.
size
(
1
)),
dtype
=
u
.
dtype
,
device
=
u
.
device
)
if
z_
is
not
None
:
if
z_
is
not
None
:
c
=
torch
.
empty_like
(
z_
)
c
=
torch
.
empty_like
(
z_
)
return
[
a
,
b
,
c
]
return
[
a
,
c
]
else
:
else
:
return
[
a
,
b
]
return
[
a
]
except
Exception
:
pass
# cutlass
# cutlass
...
@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
...
@@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
return
torch
.
ops
.
_C
.
machete_prepack_B
(
b_q_weight
,
b_type
)
# TODO: has to be a better way to do this
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
try
:
torch
.
ops
.
_C
.
permute_cols
# noqa B018
@
torch
.
library
.
register_fake
(
"_C::permute_cols"
)
@
torch
.
library
.
register_fake
(
"_C::permute_cols"
)
def
_permute_cols_fake
(
a
:
torch
.
Tensor
,
def
_permute_cols_fake
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
a
)
return
torch
.
empty_like
(
a
)
except
Exception
:
pass
def
permute_cols
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
permute_cols
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
...
@@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies
,
gating_output
)
token_expert_indicies
,
gating_output
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
torch
.
library
.
register_fake
(
"_moe_C::marlin_gemm_moe"
)
def
marlin_gemm_moe_fake
(
a
:
torch
.
Tensor
,
b_q_weights
:
torch
.
Tensor
,
sorted_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
num_experts
:
int
,
topk
:
int
,
moe_block_size
:
int
,
replicate_input
:
bool
,
apply_weights
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
topk
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
300da091
...
@@ -361,7 +361,7 @@ def selective_scan_fn(u,
...
@@ -361,7 +361,7 @@ def selective_scan_fn(u,
x
[:,
:,
0
,
0
::
2
]
=
1
x
[:,
:,
0
,
0
::
2
]
=
1
if
prev_state
is
not
None
:
if
prev_state
is
not
None
:
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
x
[:,
:,
0
,
1
::
2
].
copy_
(
prev_state
)
out
,
x
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
out
,
*
rest
=
ops
.
selective_scan_fwd
(
u
,
delta
,
A
,
B
,
C
,
D
,
z
,
delta_bias
,
delta_softplus
,
position_indices
,
x
)
delta_softplus
,
position_indices
,
x
)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
last_state
=
x
[:,
:,
-
1
,
1
::
2
]
# (batch, dim, dstate)
if
z
is
None
:
if
z
is
None
:
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
300da091
...
@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
...
@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# exllama needs to shuffle the weight after the weight is loaded
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
# here we do the shuffle on first forward pass
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment