Unverified Commit 04437e31 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[Bugfix] [torch.compile] Add Dynamo metrics context during compilation (#15639)


Signed-off-by: default avatarluka <luka@neuralmagic.com>
parent 038beded
...@@ -2,21 +2,20 @@ ...@@ -2,21 +2,20 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any, Union
import pytest import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel from vllm.config import CompilationConfig, CompilationLevel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
@pytest.fixture(params=None, name="model_info") def models_list(all: bool):
def models_list_fixture(request):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}), ("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
...@@ -33,6 +32,9 @@ def models_list_fixture(request): ...@@ -33,6 +32,9 @@ def models_list_fixture(request):
("meta-llama/Llama-3.2-1B-Instruct", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}),
] ]
if not all:
return TEST_MODELS
if is_quant_method_supported("aqlm"): if is_quant_method_supported("aqlm"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm" "quantization": "aqlm"
...@@ -77,7 +79,7 @@ def models_list_fixture(request): ...@@ -77,7 +79,7 @@ def models_list_fixture(request):
"optimization_level", "optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE], [CompilationLevel.DYNAMO_ONCE, CompilationLevel.PIECEWISE],
) )
@pytest.mark.parametrize("model_info", "", indirect=True) @pytest.mark.parametrize("model_info", models_list(all=True))
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_full_graph( def test_full_graph(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
...@@ -91,25 +93,50 @@ def test_full_graph( ...@@ -91,25 +93,50 @@ def test_full_graph(
m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") m.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1")
print(f"MODEL={model}") print(f"MODEL={model}")
prompts = [ run_model(optimization_level, model, model_kwargs)
"Hello, my name is",
"The president of the United States is",
"The capital of France is", # TODO(luka) add other supported compilation config scenarios here
"The future of AI is", @pytest.mark.parametrize(
] "compilation_config",
sampling_params = SamplingParams(temperature=0) # additional compile sizes
llm = LLM( [
model=model, CompilationConfig(level=CompilationLevel.PIECEWISE,
enforce_eager=True, compile_sizes=[1, 2])
tensor_parallel_size=1, ])
disable_custom_all_reduce=True, # only test some of the models
compilation_config=optimization_level, @pytest.mark.parametrize("model_info", models_list(all=False))
**model_kwargs, @create_new_process_for_each_test()
) def test_custom_compile_config(
outputs = llm.generate(prompts, sampling_params) model_info: tuple[str, dict[str, Any]],
compilation_config: CompilationConfig,
# Print the outputs. ):
for output in outputs: model, model_kwargs = model_info
prompt = output.prompt print(f"MODEL={model}")
generated_text = output.outputs[0].text run_model(compilation_config, model, model_kwargs)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(
model=model,
enforce_eager=True,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import contextlib
import copy import copy
import hashlib import hashlib
import importlib.metadata
import os import os
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
...@@ -9,6 +11,7 @@ from unittest.mock import patch ...@@ -9,6 +11,7 @@ from unittest.mock import patch
import torch import torch
import torch._inductor.compile_fx import torch._inductor.compile_fx
import torch.fx as fx import torch.fx as fx
from packaging.version import Version
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface): ...@@ -285,6 +288,9 @@ class InductorAdaptor(CompilerInterface):
"torch._inductor.codecache.FxGraphCache._check_can_cache", "torch._inductor.codecache.FxGraphCache._check_can_cache",
_check_can_cache)) _check_can_cache))
# Dynamo metrics context, see method for more details.
stack.enter_context(self.metrics_context())
compiled_graph = compile_fx( compiled_graph = compile_fx(
graph, graph,
example_inputs, example_inputs,
...@@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface): ...@@ -309,8 +315,14 @@ class InductorAdaptor(CompilerInterface):
hash_str = handle[0] hash_str = handle[0]
from torch._inductor.codecache import FxGraphCache from torch._inductor.codecache import FxGraphCache
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", with ExitStack() as exit_stack:
lambda *args, **kwargs: AlwaysHitShapeEnv()): exit_stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
lambda *args, **kwargs: AlwaysHitShapeEnv()))
# Dynamo metrics context, see method for more details.
exit_stack.enter_context(self.metrics_context())
if torch.__version__.startswith("2.5"): if torch.__version__.startswith("2.5"):
inductor_compiled_graph = FxGraphCache._lookup_graph( inductor_compiled_graph = FxGraphCache._lookup_graph(
hash_str, example_inputs, True, False) hash_str, example_inputs, True, False)
...@@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface): ...@@ -351,6 +363,28 @@ class InductorAdaptor(CompilerInterface):
return compiled_graph return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
Present in torch>=2.6, it's used inside FxGraphCache in
torch==2.6 (but not after). It might also be used in various other
torch.compile internal functions.
Because it is re-entrant, we always set it (even if entering via Dynamo
and the context was already entered). We might want to revisit if it
should be set at a different level of compilation.
This is likely a bug in PyTorch: public APIs should not rely on
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
else:
return contextlib.nullcontext()
class EagerAdaptor(CompilerInterface): class EagerAdaptor(CompilerInterface):
name = "eager" name = "eager"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment