Unverified commit a9656283, authored by Xin Yao, committed by GitHub

[PyTorch] Fix autocast deprecation warnings (#1277)



* Fix autocast deprecation warnings
Signed-off-by: Xin Yao <xiny@nvidia.com>

* merge main
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a07eed91
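Background for reviewers: starting with PyTorch 2.4, the device-specific autocast entry points emit FutureWarnings, and the diff below migrates each call site behind a version-gated shim. A minimal sketch of the API change itself (assumes a CUDA build of PyTorch; shapes are illustrative):

import torch

a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")

# Deprecated spelling, warns with a FutureWarning on PyTorch >= 2.4:
#     with torch.cuda.amp.autocast(enabled=True):
#         ...
# Replacement: the device-generic context manager with an explicit device type.
with torch.amp.autocast(device_type="cuda", enabled=True, dtype=torch.float16):
    c = a @ b  # matmul executes in the autocast dtype
print(c.dtype)  # torch.float16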
@@ -16,6 +16,7 @@ from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.utils import gpu_autocast_ctx

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -596,7 +597,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -605,7 +606,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -647,7 +648,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -656,7 +657,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -705,7 +706,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -714,7 +715,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
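For context, the reference/DUT blocks above follow the standard mixed-precision training step; a minimal self-contained sketch of that pattern with the new shim (the model, shapes, and optimizer here are illustrative, not the test's actual fixtures):

import functools

import torch

# Simplified version of the shim this PR adds to transformer_engine.pytorch.utils.
if tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (2, 4):
    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
    gpu_autocast_ctx = torch.cuda.amp.autocast

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters())
x = torch.randn(4, 16, device="cuda")
gt = torch.randn(4, 16, device="cuda")

with gpu_autocast_ctx(enabled=True):
    y = model(x)
    loss = ((gt - y) ** 2).mean()
loss.backward()  # the real tests route this through a GradScaler
optimizer.step()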
@@ -19,6 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

+from . import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,
     safely_set_viewless_tensor_data,
@@ -271,17 +272,36 @@ def _get_active_autocast_contexts():
     """
     autocast_cached = torch.is_autocast_cache_enabled()

-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
+    if torch_version() >= (2, 4, 0):
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda",
+            enabled=gpu_autocast_enabled,
+            dtype=gpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )

-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu",
+            enabled=cpu_autocast_enabled,
+            dtype=cpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+    else:
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )

     return gpu_autocast_ctx, cpu_autocast_ctx
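The context managers captured above are consumed later to re-enter the same autocast state (TE uses this during activation recompute). An illustrative consumer, assuming the function from the hunk above is in scope:

gpu_ctx, cpu_ctx = _get_active_autocast_contexts()

def recompute(fn, *args):
    # Re-enter the autocast state captured above so the recomputed forward
    # matches the original pass dtype-for-dtype.
    with gpu_ctx, cpu_ctx:
        return fn(*args)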
@@ -9,6 +9,9 @@ from typing import Callable, Optional, Tuple

 import torch

+from . import torch_version
+from .utils import gpu_autocast_ctx
+
 # pylint: disable=unnecessary-lambda-assignment
@@ -31,13 +34,13 @@ def lazy_compile(func):

 jit_fuser = lambda func: func
-if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = lazy_compile

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
-if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = lazy_compile
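The move from torch.__version__ string comparisons to torch_version() tuples is not just cosmetic: Python compares strings lexicographically, which misorders versions once any component reaches two digits:

print("2.10" >= "2.2")          # False: '1' < '2' lexicographically
print((2, 10, 0) >= (2, 2, 0))  # True: element-wise integer comparison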
@@ -49,11 +52,9 @@ no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)

 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+    if torch_version() >= (2, 2, 0):
         pass
-    elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+    elif torch_version() >= (1, 10, 0):
         # nvfuser
         torch._C._jit_set_profiling_executor(True)
         torch._C._jit_set_profiling_mode(True)
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
with gpu_autocast_ctx(enabled=False):
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)
@@ -132,7 +133,7 @@ def bgrad_dgelu_fused(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)
@@ -173,7 +174,7 @@ def bias_dropout_add_fused_train(
 ) -> torch.Tensor:
     """Disable native AMP and enable grad for BDA"""
     with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
             return bias_dropout_add_fused_train_(x, bias, residual, prob)
@@ -189,7 +190,7 @@ def bias_dropout_add_fused_inference(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         return bias_dropout_add_fused_inference_(x, bias, residual, prob)
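All four wrappers above share one pattern: turn autocast off around a scripted fused kernel so it sees the dtypes its callers already prepared, instead of autocast re-casting tensors mid-kernel. A reduced sketch of the pattern (the tanh-approximation GELU body is illustrative; gpu_autocast_ctx is the shim added in utils.py, so this assumes a TE build that includes this PR):

import torch

from transformer_engine.pytorch.utils import gpu_autocast_ctx

@torch.jit.script
def gelu_fused_example(x: torch.Tensor) -> torch.Tensor:
    # Tanh-approximation GELU, fused by the TorchScript executor.
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))

def gelu_no_amp(x: torch.Tensor) -> torch.Tensor:
    # Run the fused kernel with native AMP off; callers fix dtypes beforehand.
    with gpu_autocast_ctx(enabled=False):
        return gelu_fused_example(x)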
@@ -39,6 +39,7 @@ from ..tensor import QuantizedTensor, Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..utils import torch_get_autocast_gpu_dtype
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...common.recipe import Recipe
 from ...debug.pytorch.debug_state import TEDebugState
@@ -700,7 +701,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
             return

         # All checks after this have already been performed once, thus skip
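The swap to torch_get_autocast_gpu_dtype() matters because on PyTorch 2.4+ the unqualified query warns; the helper (defined at the end of this diff) dispatches to the device-qualified API. A quick check of the two spellings on a 2.4+ build:

import torch

with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    print(torch.get_autocast_dtype("cuda"))  # torch.bfloat16, 2.4+ spelling
    # torch.get_autocast_gpu_dtype() returns the same value but emits a
    # FutureWarning on PyTorch >= 2.4.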
@@ -10,6 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union

 import torch

+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
@@ -24,6 +25,7 @@ from transformer_engine.pytorch.jit import (
 from transformer_engine.pytorch.utils import (
     cast_if_needed,
     get_default_init_method,
+    torch_get_autocast_gpu_dtype,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
@@ -435,9 +437,7 @@ class TransformerLayer(torch.nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Set bias+dropout+add fusion grad_enable execution handler.
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
         self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

         if self.bias_dropout_fusion:
@@ -687,7 +687,7 @@ class TransformerLayer(torch.nn.Module):
         # For AMP
         if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

         # Self attention.
         self_attention_outputs = self.self_attention(
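One behavioral note on the use_nvfuser change above: the old predicate was true for every torch >= 1.10, while the new one adds an upper bound, consistent with set_jit_fusion_options() in jit.py, which stops configuring nvfuser from 2.2 on. A runnable restatement of the new gate (this torch_version() stand-in is hypothetical; TE's real helper lives in the transformer_engine.pytorch package):

import torch

def torch_version() -> tuple:
    # Keep only leading digits of each component so suffixes like "0a0" parse.
    return tuple(
        int("".join(ch for ch in part if ch.isdigit()) or 0)
        for part in torch.__version__.split(".")[:3]
    )

# The chained comparison evaluates torch_version() only once.
use_nvfuser = (1, 10, 0) <= torch_version() < (2, 2, 0)
print(use_nvfuser)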
@@ -9,9 +9,10 @@ import math
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch

 import transformer_engine.pytorch.cpp_extensions as ext
+from . import torch_version
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
@@ -596,3 +597,16 @@ def canonicalize_process_group(
     if group is None:
         return torch.distributed.distributed_c10d._get_default_group()
     return group
+
+
+def torch_get_autocast_gpu_dtype() -> torch.dtype:
+    """Get PyTorch autocast GPU dtype."""
+    if torch_version() >= (2, 4, 0):
+        return torch.get_autocast_dtype("cuda")
+    return torch.get_autocast_gpu_dtype()
+
+
+if torch_version() >= (2, 4, 0):
+    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
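Usage sketch for the two additions, assuming a CUDA-enabled PyTorch and a TE build that includes this PR:

import torch

from transformer_engine.pytorch.utils import (
    gpu_autocast_ctx,
    torch_get_autocast_gpu_dtype,
)

# Same call sites work on every supported torch version, with no FutureWarning.
with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
    print(torch_get_autocast_gpu_dtype())  # torch.bfloat16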