Unverified commit dec3ef1d authored by Alp Dener, committed by GitHub
Browse files

[PyTorch] Disabling TorchDynamo for TE activation checkpoint wrapper (#894)



Added @torch._disable_dynamo and fixed deprecation warnings with the torch autocast API for the TE checkpoint.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent fe80ca06
...@@ -239,13 +239,13 @@ def _get_active_autocast_contexts(): ...@@ -239,13 +239,13 @@ def _get_active_autocast_contexts():
""" """
autocast_cached = torch.is_autocast_cache_enabled() autocast_cached = torch.is_autocast_cache_enabled()
gpu_autocast_enabled = torch.is_autocast_enabled() gpu_autocast_enabled = torch.is_autocast_enabled('cuda')
gpu_autocast_dtype = torch.get_autocast_gpu_dtype() gpu_autocast_dtype = torch.get_autocast_dtype('cuda')
gpu_autocast_ctx = torch.cuda.amp.autocast( gpu_autocast_ctx = torch.cuda.amp.autocast(
gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached) gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached)
cpu_autocast_enabled = torch.is_autocast_cpu_enabled() cpu_autocast_enabled = torch.is_autocast_enabled('cpu')
cpu_autocast_dtype = torch.get_autocast_cpu_dtype() cpu_autocast_dtype = torch.get_autocast_dtype('cpu')
cpu_autocast_ctx = torch.cpu.amp.autocast( cpu_autocast_ctx = torch.cpu.amp.autocast(
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached) cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached)
...@@ -557,7 +557,7 @@ def has_te_modules(network): ...@@ -557,7 +557,7 @@ def has_te_modules(network):
# so just assume that it has TE modules just to be safe. # so just assume that it has TE modules just to be safe.
return True return True
@torch._disable_dynamo
def checkpoint( def checkpoint(
function: Callable, function: Callable,
*args: Tuple[torch.Tensor, ...], *args: Tuple[torch.Tensor, ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment