Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7b2f28de
Unverified
Commit
7b2f28de
authored
May 14, 2025
by
Charlie Fu
Committed by
GitHub
May 13, 2025
Browse files
[AMD][torch.compile] Enable silu+fp8_quant fusion for rocm (#18082)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
2d912fb6
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
14 additions
and
9 deletions
+14
-9
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
csrc/quantization/activation_kernels.cu
csrc/quantization/activation_kernels.cu
+2
-1
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+3
-3
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+3
-2
tests/kernels/test_fused_quant_activation.py
tests/kernels/test_fused_quant_activation.py
+3
-2
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+2
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
7b2f28de
...
...
@@ -309,6 +309,7 @@ steps:
commands
:
-
pytest -v -s compile/test_pass_manager.py
-
pytest -v -s compile/test_fusion.py
-
pytest -v -s compile/test_silu_mul_quant_fusion.py
-
pytest -v -s compile/test_sequence_parallelism.py
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
...
...
csrc/quantization/activation_kernels.cu
View file @
7b2f28de
...
...
@@ -112,7 +112,8 @@ __global__ void act_and_mul_quant_kernel(
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
// [..., d]
torch
::
Tensor
&
input
,
// [..., 2 * d]
torch
::
Tensor
&
scale
)
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
out
.
dtype
()
==
torch
::
kFloat8_e4m3fnuz
);
TORCH_CHECK
(
input
.
dtype
()
==
torch
::
kFloat16
||
input
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
input
.
size
(
-
1
)
%
2
==
0
);
...
...
tests/compile/test_silu_mul_quant_fusion.py
View file @
7b2f28de
...
...
@@ -27,8 +27,8 @@ class TestModel(torch.nn.Module):
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
64
])
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
!=
"cuda"
,
reason
=
"Only test on CUDA"
)
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
,
"rocm"
]
,
reason
=
"Only test on CUDA
and ROCm
"
)
def
test_fusion_silu_and_mul_quant
(
num_tokens
,
hidden_size
):
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
torch
.
float16
)
...
...
@@ -36,7 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work
config
=
VllmConfig
()
config
.
compilation_config
=
CompilationConfig
(
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_
reshape
=
True
))
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_
noop
=
True
))
fusion_pass
=
ActivationQuantFusionPass
(
config
)
backend
=
TestBackend
(
fusion_pass
)
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
7b2f28de
...
...
@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
parametrize
(
"m"
,
M
+
[
28672
])
# m >= 16
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
reason
=
"only test for rocm"
)
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
)
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
...
...
tests/kernels/test_fused_quant_activation.py
View file @
7b2f28de
...
...
@@ -5,9 +5,10 @@ import torch
import
vllm._custom_ops
as
ops
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.platforms
import
current_platform
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
QUANT_DTYPES
=
[
torch
.
float8_e4m3fn
]
QUANT_DTYPES
=
[
current_platform
.
fp8_dtype
()
]
NUM_TOKENS
=
[
1
,
17
,
86
,
1234
,
3045
]
# Arbitrary values for testing
HIDDEN_SIZES
=
[
16
,
48
,
128
,
1562
,
4096
]
# Arbitrary values for testing
SEEDS
=
[
0
]
...
...
@@ -26,7 +27,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
def
ops_impl
(
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out_shape
=
(
x
.
shape
[
0
],
x
.
shape
[
1
]
//
2
)
out
=
torch
.
empty
(
out_shape
,
dtype
=
torch
.
torch
.
float8_e4m3fn
,
dtype
=
current_platform
.
fp8_dtype
()
,
device
=
x
.
device
)
torch
.
ops
.
_C
.
silu_and_mul_quant
(
out
,
x
,
scale
)
return
out
...
...
vllm/compilation/activation_quant_fusion.py
View file @
7b2f28de
...
...
@@ -7,6 +7,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
.vllm_inductor_pass
import
VllmInductorPass
...
...
@@ -41,7 +42,7 @@ def empty_bf16(*args, **kwargs):
def
empty_fp8
(
*
args
,
**
kwargs
):
fp8
=
torch
.
float8_e4m3fn
fp8
=
current_platform
.
fp8_dtype
()
return
torch
.
empty
(
*
args
,
**
kwargs
,
dtype
=
fp8
,
device
=
"cuda"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment