Unverified Commit 8480fda6 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Fix `GenerationMixin.generate` compatibility with pytorch profiler (#31935)

use torch.compiler.is_compiling() when possible
parent 7f79a973
...@@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False): ...@@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False):
def is_torchdynamo_available(): def is_torchdynamo_available():
if not is_torch_available(): if not is_torch_available():
return False return False
try:
import torch._dynamo as dynamo # noqa: F401
return True return version.parse(_torch_version) >= version.parse("2.0.0")
except Exception:
return False
def is_torch_compile_available(): def is_torch_compile_available():
...@@ -665,6 +661,12 @@ def is_torchdynamo_compiling(): ...@@ -665,6 +661,12 @@ def is_torchdynamo_compiling():
if not is_torch_available(): if not is_torch_available():
return False return False
try: try:
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
if version.parse(_torch_version) >= version.parse("2.3.0"):
import torch
return torch.compiler.is_compiling()
else:
import torch._dynamo as dynamo # noqa: F401 import torch._dynamo as dynamo # noqa: F401
return dynamo.is_compiling() return dynamo.is_compiling()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment