Unverified Commit b5a7c9f9 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] reverting autocast API back to PyTorch v2.3.1 and below (#921)



reverting autocast API back to PyTorch v2.3.1 and below
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent 9d9c3a04
...@@ -239,13 +239,13 @@ def _get_active_autocast_contexts(): ...@@ -239,13 +239,13 @@ def _get_active_autocast_contexts():
""" """
autocast_cached = torch.is_autocast_cache_enabled() autocast_cached = torch.is_autocast_cache_enabled()
gpu_autocast_enabled = torch.is_autocast_enabled('cuda') gpu_autocast_enabled = torch.is_autocast_enabled()
gpu_autocast_dtype = torch.get_autocast_dtype('cuda') gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
gpu_autocast_ctx = torch.cuda.amp.autocast( gpu_autocast_ctx = torch.cuda.amp.autocast(
gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached) gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached)
cpu_autocast_enabled = torch.is_autocast_enabled('cpu') cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
cpu_autocast_dtype = torch.get_autocast_dtype('cpu') cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
cpu_autocast_ctx = torch.cpu.amp.autocast( cpu_autocast_ctx = torch.cpu.amp.autocast(
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached) cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment