Unverified Commit a9656283 authored by Xin Yao, committed by GitHub

[PyTorch] Fix autocast deprecation warnings (#1277)



* Fix autocast deprecation warnings
Signed-off-by: Xin Yao <xiny@nvidia.com>

* merge main
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a07eed91
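
For context, a minimal sketch of the warning this commit works around (illustration only, not part of the change; it assumes a CUDA build of PyTorch 2.4 or newer, where the legacy `torch.cuda.amp.autocast` entry point is deprecated in favor of the device-agnostic `torch.amp.autocast`):

    import torch

    x = torch.ones(4, device="cuda")

    with torch.cuda.amp.autocast(enabled=True):                  # deprecated, emits a FutureWarning
        y = x @ x.new_ones(4, 4)

    with torch.amp.autocast(device_type="cuda", enabled=True):   # non-deprecated equivalent
        y = x @ x.new_ones(4, 4)
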
@@ -16,6 +16,7 @@ from transformer_engine.pytorch.attention.multi_head_attention import MultiheadA
 from transformer_engine.pytorch import fp8_model_init
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.utils import gpu_autocast_ctx

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -596,7 +597,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -605,7 +606,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -647,7 +648,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -656,7 +657,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
@@ -705,7 +706,7 @@ class AdamTest:
         gt_ = gt.clone()

         # Reference
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model(x)
             loss = ((gt - y) ** 2).mean()
@@ -714,7 +715,7 @@ class AdamTest:
         scaler.update()

         # DUT
-        with torch.cuda.amp.autocast(enabled=True):
+        with gpu_autocast_ctx(enabled=True):
             y = self.model_(x)
             loss_ = ((gt_ - y) ** 2).mean()
...
@@ -19,6 +19,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
 from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules

+from . import torch_version
 from .utils import (
     is_non_tn_fp8_gemm_supported,
     safely_set_viewless_tensor_data,
@@ -271,6 +272,25 @@ def _get_active_autocast_contexts():
     """
     autocast_cached = torch.is_autocast_cache_enabled()

+    if torch_version() >= (2, 4, 0):
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda",
+            enabled=gpu_autocast_enabled,
+            dtype=gpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu",
+            enabled=cpu_autocast_enabled,
+            dtype=cpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+    else:
-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
...
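
The `_get_active_autocast_contexts` change above captures the live autocast state through the device-keyed query APIs that PyTorch 2.4 introduced. A minimal sketch of that capture-and-replay pattern (illustrative only; it assumes PyTorch >= 2.4 with a CUDA device and does not use Transformer Engine itself):

    import torch

    def capture_cuda_autocast() -> torch.amp.autocast:
        # Build a fresh autocast context that mirrors the caller's current CUDA autocast state.
        return torch.amp.autocast(
            "cuda",
            enabled=torch.is_autocast_enabled("cuda"),
            dtype=torch.get_autocast_dtype("cuda"),
            cache_enabled=torch.is_autocast_cache_enabled(),
        )

    with torch.autocast("cuda", dtype=torch.bfloat16):
        ctx = capture_cuda_autocast()          # snapshot the active settings

    with ctx:                                  # re-enter them later, e.g. in a checkpoint recompute
        assert torch.get_autocast_dtype("cuda") == torch.bfloat16
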
@@ -9,6 +9,9 @@ from typing import Callable, Optional, Tuple

 import torch

+from . import torch_version
+from .utils import gpu_autocast_ctx
+
 # pylint: disable=unnecessary-lambda-assignment
@@ -31,13 +34,13 @@ def lazy_compile(func):
 jit_fuser = lambda func: func
-if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = lazy_compile

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
-if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = lazy_compile
@@ -49,11 +52,9 @@ no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recu
 def set_jit_fusion_options() -> None:
     """Set PyTorch JIT layer fusion options."""
     # flags required to enable jit fusion kernels
-    TORCH_MAJOR = int(torch.__version__.split(".")[0])
-    TORCH_MINOR = int(torch.__version__.split(".")[1])
-    if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+    if torch_version() >= (2, 2, 0):
         pass
-    elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+    elif torch_version() >= (1, 10, 0):
         # nvfuser
         torch._C._jit_set_profiling_executor(True)
         torch._C._jit_set_profiling_mode(True)
@@ -122,7 +123,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bias_gelu_fused_(inp, bias)
         return gelu_fused_(inp)
@@ -132,7 +133,7 @@ def bgrad_dgelu_fused(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
 ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         if bias is not None and bias.numel() != 0:
             return bgrad_dgelu_fused_(grad_output, inp, bias)
         return None, dgelu_fused_(grad_output, inp)
@@ -173,7 +174,7 @@ def bias_dropout_add_fused_train(
 ) -> torch.Tensor:
     """Disable native AMP and enable grad for BDA"""
     with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
             return bias_dropout_add_fused_train_(x, bias, residual, prob)
@@ -189,7 +190,7 @@ def bias_dropout_add_fused_inference(
     x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
 ) -> torch.Tensor:
     """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
         return bias_dropout_add_fused_inference_(x, bias, residual, prob)
...
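
The wrappers above switch autocast off so the fused kernels run in the tensors' own dtype rather than the autocast dtype. A short illustration of that pattern with the new alias (a sketch; it relies only on `gpu_autocast_ctx` being importable from `transformer_engine.pytorch.utils`, as the diff adds, and on a CUDA device being present):

    import torch
    from transformer_engine.pytorch.utils import gpu_autocast_ctx

    def fused_like_op(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # Suspend any enclosing autocast region inside the fused computation.
        with gpu_autocast_ctx(enabled=False):
            return torch.nn.functional.gelu(x + bias)

    x = torch.randn(8, 8, device="cuda")
    b = torch.zeros(8, device="cuda")
    with gpu_autocast_ctx(enabled=True, dtype=torch.float16):
        out = fused_like_op(x, b)
    assert out.dtype == torch.float32   # autocast was disabled inside the wrapper
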
@@ -39,6 +39,7 @@ from ..tensor import QuantizedTensor, Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..utils import torch_get_autocast_gpu_dtype
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...common.recipe import Recipe
 from ...debug.pytorch.debug_state import TEDebugState
@@ -700,7 +701,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
             return

         # All checks after this have already been performed once, thus skip
...
@@ -10,6 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union

 import torch

+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
@@ -24,6 +25,7 @@ from transformer_engine.pytorch.jit import (
 from transformer_engine.pytorch.utils import (
     cast_if_needed,
     get_default_init_method,
+    torch_get_autocast_gpu_dtype,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
@@ -435,9 +437,7 @@ class TransformerLayer(torch.nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Set bias+dropout+add fusion grad_enable execution handler.
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
         self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

         if self.bias_dropout_fusion:
@@ -687,7 +687,7 @@ class TransformerLayer(torch.nn.Module):
         # For AMP
         if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

         # Self attention.
         self_attention_outputs = self.self_attention(
...
@@ -9,9 +9,10 @@ import math
 import os
 from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext

+from . import torch_version
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
@@ -596,3 +597,16 @@ def canonicalize_process_group(
     if group is None:
         return torch.distributed.distributed_c10d._get_default_group()
     return group
+
+
+def torch_get_autocast_gpu_dtype() -> torch.dtype:
+    """Get PyTorch autocast GPU dtype."""
+    if torch_version() >= (2, 4, 0):
+        return torch.get_autocast_dtype("cuda")
+    return torch.get_autocast_gpu_dtype()
+
+
+if torch_version() >= (2, 4, 0):
+    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
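
A small usage sketch of the two helpers added here, `torch_get_autocast_gpu_dtype` and `gpu_autocast_ctx` (illustrative; it assumes they stay importable from `transformer_engine.pytorch.utils`, as the test changes above do, and that a CUDA device is present):

    import torch
    from transformer_engine.pytorch.utils import gpu_autocast_ctx, torch_get_autocast_gpu_dtype

    # gpu_autocast_ctx resolves to torch.amp.autocast("cuda", ...) on PyTorch >= 2.4
    # and to torch.cuda.amp.autocast(...) on older versions, so neither path warns.
    with gpu_autocast_ctx(enabled=True, dtype=torch.bfloat16):
        assert torch_get_autocast_gpu_dtype() == torch.bfloat16

    linear = torch.nn.Linear(16, 16).cuda()
    with gpu_autocast_ctx(enabled=True):
        out = linear(torch.randn(4, 16, device="cuda"))
    assert out.dtype == torch.float16   # default CUDA autocast dtype
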