Unverified Commit b957aa47 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Fix compatibility with pyTorch 2.0 (#627)


Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent bea70f2e
...@@ -22,7 +22,12 @@ if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1")) ...@@ -22,7 +22,12 @@ if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))
no_torch_dynamo = lambda recursive=True: lambda func: func no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2": if torch.__version__ >= "2":
import torch._dynamo import torch._dynamo
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: \
torch._dynamo.disable(f, recursive=recursive)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
def set_jit_fusion_options() -> None: def set_jit_fusion_options() -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment