Unverified Commit 8480fda6 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Fix `GenerationMixin.generate` compatibility with pytorch profiler (#31935)

use torch.compiler.is_compiling() when possible
parent 7f79a973
......@@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False):
def is_torchdynamo_available():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
return True
except Exception:
return False
return version.parse(_torch_version) >= version.parse("2.0.0")
def is_torch_compile_available():
......@@ -665,9 +661,15 @@ def is_torchdynamo_compiling():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
if version.parse(_torch_version) >= version.parse("2.3.0"):
import torch
return torch.compiler.is_compiling()
else:
import torch._dynamo as dynamo # noqa: F401
return dynamo.is_compiling()
return dynamo.is_compiling()
except Exception:
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment