Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9cdde472
Unverified
Commit
9cdde472
authored
Apr 09, 2025
by
Luka Govedič
Committed by
GitHub
Apr 08, 2025
Browse files
[BugFix] Fix fusion test and add them to CI (#16287)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
b1eb4ca1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
49 deletions
+74
-49
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+8
-1
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+58
-45
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+8
-3
No files found.
.buildkite/test-pipeline.yaml
View file @
9cdde472
...
...
@@ -292,6 +292,14 @@ steps:
command
:
pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
parallelism
:
4
-
label
:
PyTorch Compilation Unit Tests
source_file_dependencies
:
-
vllm/
-
tests/compile
commands
:
-
pytest -v -s compile/test_pass_manager.py
-
pytest -v -s compile/test_fusion.py
-
label
:
PyTorch Fullgraph Smoke Test
# 9min
source_file_dependencies
:
-
vllm/
...
...
@@ -301,7 +309,6 @@ steps:
# these tests need to be separated, cannot combine
-
pytest -v -s compile/piecewise/test_simple.py
-
pytest -v -s compile/piecewise/test_toy_llama.py
-
pytest -v -s compile/test_pass_manager.py
-
label
:
PyTorch Fullgraph Test
# 18min
source_file_dependencies
:
...
...
tests/compile/test_full_graph.py
View file @
9cdde472
...
...
@@ -2,7 +2,7 @@
from
__future__
import
annotations
from
typing
import
Any
,
Union
from
typing
import
Any
,
Optional
,
Union
import
pytest
import
torch
...
...
@@ -15,7 +15,7 @@ from vllm.platforms import current_platform
from
..utils
import
create_new_process_for_each_test
def
models_list
(
all
:
bool
):
def
models_list
(
*
,
all
:
bool
=
True
,
keywords
:
Optional
[
list
[
str
]]
=
None
):
TEST_MODELS
:
list
[
tuple
[
str
,
dict
[
str
,
Any
]]]
=
[
(
"facebook/opt-125m"
,
{}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
...
...
@@ -32,9 +32,7 @@ def models_list(all: bool):
(
"meta-llama/Llama-3.2-1B-Instruct"
,
{}),
]
if
not
all
:
return
TEST_MODELS
if
all
:
if
is_quant_method_supported
(
"aqlm"
):
TEST_MODELS
.
append
((
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
,
{
"quantization"
:
"aqlm"
...
...
@@ -72,8 +70,13 @@ def models_list(all: bool):
"quantization"
:
"AWQ"
}))
if
keywords
is
None
:
return
TEST_MODELS
# filter by keywords
pred
=
lambda
model
:
any
(
keyword
in
model
[
0
]
for
keyword
in
keywords
)
return
list
(
filter
(
pred
,
TEST_MODELS
))
@
pytest
.
mark
.
parametrize
(
"optimization_level"
,
...
...
@@ -96,20 +99,30 @@ def test_full_graph(
run_model
(
optimization_level
,
model
,
model_kwargs
)
PassConfig
=
CompilationConfig
.
PassConfig
# TODO(luka) add other supported compilation config scenarios here
@
pytest
.
mark
.
parametrize
(
"compilation_config"
,
# additional compile sizes
"compilation_config, model_info"
,
[
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compile_sizes
=
[
1
,
2
])
# additional compile sizes, only some of the models
(
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
compile_sizes
=
[
1
,
2
]),
model
)
for
model
in
models_list
(
all
=
False
)
]
+
[
# RMSNorm + quant fusion, only 8-bit quant models
(
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
custom_ops
=
[
"+rms_norm"
],
pass_config
=
PassConfig
(
enable_fusion
=
True
,
enable_noop
=
True
)),
model
)
for
model
in
models_list
(
keywords
=
[
"FP8-dynamic"
,
"quantized.w8a8"
])
])
# only test some of the models
@
pytest
.
mark
.
parametrize
(
"model_info"
,
models_list
(
all
=
False
))
@
create_new_process_for_each_test
()
def
test_custom_compile_config
(
model_info
:
tuple
[
str
,
dict
[
str
,
Any
]],
compilation_config
:
CompilationConfig
,
model_info
:
tuple
[
str
,
dict
[
str
,
Any
]],
):
model
,
model_kwargs
=
model_info
print
(
f
"MODEL=
{
model
}
"
)
...
...
tests/compile/test_fusion.py
View file @
9cdde472
...
...
@@ -44,12 +44,17 @@ class TestModel(torch.nn.Module):
resid
=
torch
.
sqrt
(
x
)
y
=
self
.
norm
[
0
](
x
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
self
.
scale
[
0
])
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
])
# make sure resid is used for replacement to work
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
self
.
scale
[
1
])
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
])
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
return
y3
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment