Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bd37b9fb
Unverified
Commit
bd37b9fb
authored
Oct 08, 2024
by
bnellnm
Committed by
GitHub
Oct 08, 2024
Browse files
[Bugfix] Try to handle older versions of pytorch (#9086)
parent
de24046f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
21 deletions
+41
-21
tests/kernels/test_awq.py
tests/kernels/test_awq.py
+5
-0
tests/kernels/test_awq_marlin.py
tests/kernels/test_awq_marlin.py
+4
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+32
-21
No files found.
tests/kernels/test_awq.py
View file @
bd37b9fb
import
os
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
@
pytest
.
mark
.
skipif
(
not
hasattr
(
torch
.
ops
.
_C
,
"awq_dequantize"
),
reason
=
"AWQ is not supported on this GPU type."
)
def
test_awq_dequantize_opcheck
():
os
.
environ
[
"VLLM_USE_TRITON_AWQ"
]
=
"0"
qweight
=
torch
.
randint
(
-
2000000000
,
...
...
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(
qweight
,
scales
,
zeros
,
split_k_iters
,
thx
,
thy
))
@
pytest
.
mark
.
skipif
(
not
hasattr
(
torch
.
ops
.
_C
,
"awq_gemm"
),
reason
=
"AWQ is not supported on this GPU type."
)
def
test_awq_gemm_opcheck
():
os
.
environ
[
"VLLM_USE_TRITON_AWQ"
]
=
"0"
input
=
torch
.
rand
((
2
,
8192
),
device
=
'cuda'
,
dtype
=
torch
.
float16
)
...
...
tests/kernels/test_awq_marlin.py
View file @
bd37b9fb
...
...
@@ -7,6 +7,7 @@ import torch
from
tests.kernels.utils
import
(
compute_max_diff
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
,
single_marlin_moe
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
...
...
@@ -21,6 +22,9 @@ from vllm.scalar_type import scalar_types
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
6
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
64
,
128
])
@
pytest
.
mark
.
skipif
(
not
(
ops
.
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
)),
reason
=
"Marlin is not supported on this GPU type."
)
def
test_fused_marlin_moe_awq
(
m
:
int
,
n
:
int
,
...
...
vllm/_custom_ops.py
View file @
bd37b9fb
import
contextlib
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.library
import
vllm.envs
as
envs
from
vllm._core_ext
import
ScalarType
...
...
@@ -25,6 +26,16 @@ with contextlib.suppress(ImportError):
import
vllm._moe_C
# noqa: F401
supports_moe_ops
=
True
if
TYPE_CHECKING
:
def
register_fake
(
fn
):
return
lambda
name
:
fn
else
:
try
:
from
torch.library
import
register_fake
except
ImportError
:
from
torch.library
import
impl_abstract
as
register_fake
def
hint_on_error
(
fn
):
...
...
@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_gemm"
):
@
torch
.
library
.
register_fake
(
"_C::gptq_gemm"
)
@
register_fake
(
"_C::gptq_gemm"
)
def
_gptq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_gptq_qzeros
:
torch
.
Tensor
,
b_gptq_scales
:
torch
.
Tensor
,
b_g_idx
:
torch
.
Tensor
,
...
...
@@ -301,7 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
if
hasattr
(
torch
.
ops
.
_C
,
"gptq_marlin_24_gemm"
):
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_24_gemm"
)
@
register_fake
(
"_C::gptq_marlin_24_gemm"
)
def
_gptq_marlin_24_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_meta
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
...
...
@@ -309,7 +320,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
torch
.
library
.
register_fake
(
"_C::gptq_marlin_gemm"
)
@
register_fake
(
"_C::gptq_marlin_gemm"
)
def
_gptq_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
...
...
@@ -326,12 +337,12 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
use_fp32_reduce
:
bool
=
False
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
torch
.
library
.
register_fake
(
"_C::ggml_dequantize"
)
@
register_fake
(
"_C::ggml_dequantize"
)
def
_ggml_dequantize_fake
(
W
:
torch
.
Tensor
,
quant_type
:
int
,
m
:
int
,
n
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
torch
.
library
.
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
@
register_fake
(
"_C::ggml_mul_mat_vec_a8"
)
def
_ggml_mul_mat_vec_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
...
...
@@ -340,7 +351,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
)
->
torch
.
Tensor
:
return
torch
.
empty
((
1
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
torch
.
library
.
register_fake
(
"_C::ggml_mul_mat_a8"
)
@
register_fake
(
"_C::ggml_mul_mat_a8"
)
def
_ggml_mul_mat_a8_fake
(
W
:
torch
.
Tensor
,
X
:
torch
.
Tensor
,
...
...
@@ -350,7 +361,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
batch
=
X
.
size
(
0
)
return
torch
.
empty
((
batch
,
row
),
dtype
=
torch
.
float16
,
device
=
W
.
device
)
@
torch
.
library
.
register_fake
(
"_C::marlin_qqq_gemm"
)
@
register_fake
(
"_C::marlin_qqq_gemm"
)
def
_marlin_qqq_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
s_tok
:
torch
.
Tensor
,
s_ch
:
torch
.
Tensor
,
s_group
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
...
...
@@ -360,7 +371,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
torch
.
float16
,
device
=
a
.
device
)
@
torch
.
library
.
register_fake
(
"_C::marlin_gemm"
)
@
register_fake
(
"_C::marlin_gemm"
)
def
_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
size_n
:
int
,
...
...
@@ -369,7 +380,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
torch
.
float16
,
device
=
a
.
device
)
@
torch
.
library
.
register_fake
(
"_C::awq_dequantize"
)
@
register_fake
(
"_C::awq_dequantize"
)
def
_awq_dequantize_fake
(
qweight
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
zeros
:
torch
.
Tensor
,
split_k_iters
:
int
,
thx
:
int
,
thy
:
int
)
->
torch
.
Tensor
:
...
...
@@ -380,7 +391,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
scales
.
dtype
,
device
=
scales
.
device
)
@
torch
.
library
.
register_fake
(
"_C::awq_gemm"
)
@
register_fake
(
"_C::awq_gemm"
)
def
_awq_gemm_fake
(
input
:
torch
.
Tensor
,
qweight
:
torch
.
Tensor
,
qzeros
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
split_k_iters
:
int
)
->
torch
.
Tensor
:
...
...
@@ -389,7 +400,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
input
.
dtype
,
device
=
input
.
device
).
sum
(
0
)
@
torch
.
library
.
register_fake
(
"_C::aqlm_gemm"
)
@
register_fake
(
"_C::aqlm_gemm"
)
def
_aqlm_gemm_fake
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
List
[
int
],
...
...
@@ -405,7 +416,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
output_sizes
.
append
(
-
1
)
return
flat_output
.
reshape
(
tuple
(
output_sizes
))
@
torch
.
library
.
register_fake
(
"_C::aqlm_dequant"
)
@
register_fake
(
"_C::aqlm_dequant"
)
def
_aqlm_dequant_fake
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
List
[
int
])
->
torch
.
Tensor
:
...
...
@@ -415,14 +426,14 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype
=
codebooks
.
dtype
,
device
=
codebooks
.
device
)
@
torch
.
library
.
register_fake
(
"_C::fp8_marlin_gemm"
)
@
register_fake
(
"_C::fp8_marlin_gemm"
)
def
_fp8_marlin_gemm_fake
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
,
size_n
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
@
torch
.
library
.
register_fake
(
"_C::machete_gemm"
)
@
register_fake
(
"_C::machete_gemm"
)
def
machete_gemm_fake
(
a
:
torch
.
Tensor
,
# Should be the tensor returned by machete_prepack_B
...
...
@@ -440,13 +451,13 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
n
=
b_q
.
size
(
1
)
return
torch
.
empty
((
m
,
n
),
device
=
a
.
device
,
dtype
=
a
.
dtype
)
@
torch
.
library
.
register_fake
(
"_C::machete_prepack_B"
)
@
register_fake
(
"_C::machete_prepack_B"
)
def
machete_prepack_B_fake
(
b_q_weight
:
torch
.
Tensor
,
b_type
:
ScalarType
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
b_q_weight
,
memory_format
=
torch
.
contiguous_format
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_fwd"
)
@
register_fake
(
"_C::causal_conv1d_fwd"
)
def
causal_conv1d_fwd_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
conv_states
:
Optional
[
torch
.
Tensor
],
...
...
@@ -456,7 +467,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
silu_activation
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::causal_conv1d_update"
)
@
register_fake
(
"_C::causal_conv1d_update"
)
def
causal_conv1d_update_fake
(
x
:
torch
.
Tensor
,
conv_state
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias_
:
Optional
[
torch
.
Tensor
],
silu_activation
:
bool
,
...
...
@@ -464,7 +475,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
conv_state_indices
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
@
torch
.
library
.
register_fake
(
"_C::selective_scan_fwd"
)
@
register_fake
(
"_C::selective_scan_fwd"
)
def
selective_scan_fwd_fake
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
...
...
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
if
hasattr
(
torch
.
ops
.
_C
,
"permute_cols"
):
@
torch
.
library
.
register_fake
(
"_C::permute_cols"
)
@
register_fake
(
"_C::permute_cols"
)
def
_permute_cols_fake
(
a
:
torch
.
Tensor
,
perm
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
a
)
...
...
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
torch
.
library
.
register_fake
(
"_moe_C::marlin_gemm_moe"
)
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
def
marlin_gemm_moe_fake
(
a
:
torch
.
Tensor
,
b_q_weights
:
torch
.
Tensor
,
sorted_ids
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment