Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04437e31
Unverified
Commit
04437e31
authored
Mar 28, 2025
by
Luka Govedič
Committed by
GitHub
Mar 28, 2025
Browse files
[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)
Signed-off-by:
luka
<
luka@neuralmagic.com
>
parent
038beded
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
29 deletions
+90
-29
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+54
-27
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+36
-2
No files found.
tests/compile/test_full_graph.py
View file @
04437e31
...
...
@@ -2,21 +2,20 @@
from
__future__
import
annotations
from
typing
import
Any
from
typing
import
Any
,
Union
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.platforms
import
current_platform
from
..utils
import
create_new_process_for_each_test
@
pytest
.
fixture
(
params
=
None
,
name
=
"model_info"
)
def
models_list_fixture
(
request
):
def
models_list
(
all
:
bool
):
TEST_MODELS
:
list
[
tuple
[
str
,
dict
[
str
,
Any
]]]
=
[
(
"facebook/opt-125m"
,
{}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
...
...
@@ -33,6 +32,9 @@ def models_list_fixture(request):
(
"meta-llama/Llama-3.2-1B-Instruct"
,
{}),
]
if
not
all
:
return
TEST_MODELS
if
is_quant_method_supported
(
"aqlm"
):
TEST_MODELS
.
append
((
"ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
,
{
"quantization"
:
"aqlm"
...
...
@@ -77,7 +79,7 @@ def models_list_fixture(request):
"optimization_level"
,
[
CompilationLevel
.
DYNAMO_ONCE
,
CompilationLevel
.
PIECEWISE
],
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
""
,
indirect
=
True
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
models_list
(
all
=
True
)
)
@
create_new_process_for_each_test
()
def
test_full_graph
(
monkeypatch
:
pytest
.
MonkeyPatch
,
...
...
@@ -91,25 +93,50 @@ def test_full_graph(
m
.
setenv
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
print
(
f
"MODEL=
{
model
}
"
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
,
tensor_parallel_size
=
1
,
disable_custom_all_reduce
=
True
,
compilation_config
=
optimization_level
,
**
model_kwargs
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
run_model
(
optimization_level
,
model
,
model_kwargs
)
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
    "compilation_config",
    # additional compile sizes
    [
        CompilationConfig(level=CompilationLevel.PIECEWISE,
                          compile_sizes=[1, 2])
    ])
# only test some of the models
@pytest.mark.parametrize("model_info", models_list(all=False))
@create_new_process_for_each_test()
def test_custom_compile_config(
    model_info: tuple[str, dict[str, Any]],
    compilation_config: CompilationConfig,
):
    """Run a model end-to-end with an explicit ``CompilationConfig``.

    Args:
        model_info: ``(model, kwargs)`` pair; the kwargs are forwarded to
            the ``LLM`` constructor by ``run_model``.
        compilation_config: The compilation configuration under test.
    """
    model, llm_kwargs = model_info
    print(f"MODEL={model}")
    run_model(compilation_config, model, llm_kwargs)
def run_model(compile_config: Union[int, CompilationConfig], model: str,
              model_kwargs: dict[str, Any]):
    """Build an ``LLM`` for ``model`` and generate for a fixed prompt set.

    Args:
        compile_config: Passed through as the ``compilation_config``
            argument of ``LLM`` (either an integer level or a
            ``CompilationConfig``).
        model: Model name passed to ``LLM``.
        model_kwargs: Extra keyword arguments forwarded to ``LLM``.
    """
    test_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # temperature=0 is the only sampling option set here.
    greedy_params = SamplingParams(temperature=0)
    engine = LLM(
        model=model,
        enforce_eager=True,
        tensor_parallel_size=1,
        disable_custom_all_reduce=True,
        compilation_config=compile_config,
        **model_kwargs,
    )
    results = engine.generate(test_prompts, greedy_params)

    # Print the outputs.
    for request_output in results:
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
vllm/compilation/compiler_interface.py
View file @
04437e31
# SPDX-License-Identifier: Apache-2.0
import
contextlib
import
copy
import
hashlib
import
importlib.metadata
import
os
from
contextlib
import
ExitStack
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -9,6 +11,7 @@ from unittest.mock import patch
import
torch
import
torch._inductor.compile_fx
import
torch.fx
as
fx
from
packaging.version
import
Version
from
vllm.config
import
VllmConfig
...
...
@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
"torch._inductor.codecache.FxGraphCache._check_can_cache"
,
_check_can_cache
))
# Dynamo metrics context, see method for more details.
stack
.
enter_context
(
self
.
metrics_context
())
compiled_graph
=
compile_fx
(
graph
,
example_inputs
,
...
...
@@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface):
hash_str
=
handle
[
0
]
from
torch._inductor.codecache
import
FxGraphCache
with
patch
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
lambda
*
args
,
**
kwargs
:
AlwaysHitShapeEnv
()):
with
ExitStack
()
as
exit_stack
:
exit_stack
.
enter_context
(
patch
(
"torch._inductor.codecache.FxGraphCache._get_shape_env"
,
lambda
*
args
,
**
kwargs
:
AlwaysHitShapeEnv
()))
# Dynamo metrics context, see method for more details.
exit_stack
.
enter_context
(
self
.
metrics_context
())
if
torch
.
__version__
.
startswith
(
"2.5"
):
inductor_compiled_graph
=
FxGraphCache
.
_lookup_graph
(
hash_str
,
example_inputs
,
True
,
False
)
...
...
@@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface):
return
compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager:
    """
    Return the Dynamo metrics context, or a null context if unavailable.

    The metrics context is present in torch>=2.6. It is used inside
    FxGraphCache in torch==2.6 (but not after), and might also be used by
    other torch.compile internal functions.

    Because the context is re-entrant, it is always set here, even when
    entering via Dynamo with the context already entered. We might want to
    revisit whether it should be set at a different level of compilation.

    This is likely a bug in PyTorch: public APIs should not rely on
    manually setting up internal contexts. But we also rely on non-public
    APIs which might not provide these guarantees.
    """
    installed_torch = Version(importlib.metadata.version('torch'))
    if installed_torch < Version("2.6"):
        # Older torch has no metrics context to enter.
        return contextlib.nullcontext()

    import torch._dynamo.utils
    return torch._dynamo.utils.get_metrics_context()
class
EagerAdaptor
(
CompilerInterface
):
name
=
"eager"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment