"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "07b750a2a1a8a3001fc64e5635f8ae985b12154b"
Commit daa5e184 authored by Kirthi Shankar Sivamani, committed by GitHub

Remove deprecated APIs (#464)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 29b4670c
@@ -6,47 +6,9 @@
 from . import flax
 from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
 from .sharding import MajorShardingType, ShardingResource, ShardingType
-from ..common.utils import deprecate_wrapper
-
-extend_logical_axis_rules = deprecate_wrapper(
-    flax.extend_logical_axis_rules,
-    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
-                                 "DenseGeneral is moving to transformer_engine.jax.flax module"
-                                 " and will be fully removed in the next release (v1.0.0).")
-LayerNorm = deprecate_wrapper(flax.LayerNorm,
-                              "LayerNorm is moving to transformer_engine.jax.flax module"
-                              " and will be fully removed in the next release (v1.0.0).")
-LayerNormDenseGeneral = deprecate_wrapper(
-    flax.LayerNormDenseGeneral,
-    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
-                                 "LayerNormMLP is moving to transformer_engine.jax.flax module"
-                                 " and will be fully removed in the next release (v1.0.0).")
-TransformerEngineBase = deprecate_wrapper(
-    flax.TransformerEngineBase,
-    "TransformerEngineBase is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-MultiHeadAttention = deprecate_wrapper(
-    flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-RelativePositionBiases = deprecate_wrapper(
-    flax.RelativePositionBiases,
-    "RelativePositionBiases is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-TransformerLayer = deprecate_wrapper(
-    flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-TransformerLayerType = deprecate_wrapper(
-    flax.TransformerLayerType,
-    "TransformerLayerType is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
 
 __all__ = [
     'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
-    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', 'DenseGeneral',
-    'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase',
-    'MultiHeadAttention', 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
+    'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis',
 ]
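Note: downstream code that still imports these classes from the package root will now fail with an ImportError. A minimal migration sketch (import paths taken from the deprecation messages above):

    # Before this commit (worked, but emitted a DeprecationWarning):
    # from transformer_engine.jax import DenseGeneral, TransformerLayer

    # After this commit, import from the flax submodule directly:
    from transformer_engine.jax.flax import DenseGeneral, TransformerLayer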
@@ -7,3 +7,9 @@ from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
 from .transformer import extend_logical_axis_rules
 from .transformer import MultiHeadAttention, RelativePositionBiases
 from .transformer import TransformerLayer, TransformerLayerType
+
+__all__ = [
+    'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP',
+    'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention',
+    'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType',
+]
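With the new `__all__`, the flax submodule declares its public surface explicitly, so a star import pulls in exactly the listed names. A small usage sketch, assuming transformer_engine is installed with JAX support:

    from transformer_engine.jax import flax as te_flax

    # Only the names listed in __all__ above are public API:
    print(te_flax.__all__)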
@@ -334,9 +334,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Save before checkpointing."""
         state = None
 
-        # Maintain backward compatibility.
-        fp8_checkpoint = "fp8_checkpoint" in self.fp8_meta and self.fp8_meta["fp8_checkpoint"]
-        fp8_checkpoint = fp8_checkpoint or self.fp8 or self.fp8_calibration
+        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
 
         if fp8_checkpoint:
             state = {}
@@ -369,44 +367,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return
 
-        # Maintain backward compatibility with v0.2.0 and older.
-        if isinstance(state, list):
-            warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in the next release (v1.0.0)."
-            )
-
-            # Retrieve checkpointed items.
-            scale_fwd = state[0]
-            amax_history_fwd = state[1]
-            scale_bwd = state[2]
-            amax_history_bwd = state[3]
-            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
-            self.fp8_meta["num_gemms"] = (
-                amax_history_fwd.shape[1] // 2
-            )  # Two FWD tensors per GEMM
-
-            # Initialize before loading
-            self.init_fp8_meta_tensors()
-            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
-            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
-            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
-            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
-
-            # Restore global FP8 buffer state.
-            FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
-            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
-            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
-            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
-            self.fp8_meta["autocast_id_fwd"] = state[8]
-            self.fp8_meta["autocast_id_bwd"] = state[9]
-            return
-
         if isinstance(state, torch.Tensor):
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
             state.seek(0)
             state = torch.load(state, map_location='cuda')
+        else:
+            raise RuntimeError("Unsupported checkpoint format.")
 
         if state is None:
             return
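Checkpoints written by v0.2.0 and older stored the FP8 extra state as a positional list; loading one now raises RuntimeError("Unsupported checkpoint format."). If such a checkpoint must be recovered, the field order can be read off the deleted branch above. A hedged sketch (`legacy_state` is a hypothetical list loaded from an old checkpoint, not a TE API):

    # Positional layout of the pre-v0.2.0 list format, per the removed code:
    (scale_fwd, amax_history_fwd, scale_bwd, amax_history_bwd,
     global_fp8_buffer, update_amax_and_scale_fwd,
     global_fp8_buffer_pos_fwd, global_fp8_buffer_pos_bwd,
     autocast_id_fwd, autocast_id_bwd) = legacy_state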
@@ -414,13 +381,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Restore global FP8 amax buffer.
         FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
 
         # Restore global FP8 state.
-        if "global_fp8_state" in state:
-            FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
-        else:
-            warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in the next release (v1.0.0)."
-            )
+        FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
 
         # Load extra items.
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
@@ -433,18 +395,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
         self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
         self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-
-        # Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
-        if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
-            assert (
-                "scale_inv_fwd" not in state and "scale_inv_bwd" not in state
-            ), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
-            self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
-            self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
-        else:
-            self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-            self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
+        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
 
     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
...
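Loading now assumes `scale_inv_fwd` and `scale_inv_bwd` were saved with the checkpoint. For an older extra state that lacks them, the values can be reconstructed before loading exactly as the deleted branch did. A hedged sketch (`state` is the deserialized extra-state dict):

    # scale_inv is the elementwise reciprocal of scale:
    if "scale_inv_fwd" not in state:
        state["scale_inv_fwd"] = 1.0 / state["scale_fwd"]
        state["scale_inv_bwd"] = 1.0 / state["scale_bwd"]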
@@ -4,7 +4,7 @@
 """LayerNorm API"""
 import os
-from typing import Union, Tuple, Any, Mapping, Optional
+from typing import Union, Tuple, Optional
 
 import torch
 from torch.nn.parameter import Parameter
@@ -148,23 +148,6 @@ class LayerNorm(torch.nn.Module):
         self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
 
-    def load_state_dict(
-        self,
-        state_dict: Mapping[str, Any],
-        strict: bool = True,
-    ) -> None:
-        """Override PyTorch loader to maintain backward compatibility
-        with previous version of LayerNorm parameter names.
-        """
-        if "layer_norm_weight" in state_dict:
-            state_dict["weight"] = state_dict["layer_norm_weight"]
-            del state_dict["layer_norm_weight"]
-        if "layer_norm_bias" in state_dict:
-            state_dict["bias"] = state_dict["layer_norm_bias"]
-            del state_dict["layer_norm_bias"]
-
-        super().load_state_dict(state_dict, strict)
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         if not self.zero_centered_gamma:
@@ -173,16 +156,9 @@ class LayerNorm(torch.nn.Module):
             init.zeros_(self.weight)
         init.zeros_(self.bias)
 
     @no_torch_dynamo
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
-        # Maintain backward compatibility.
-        if hasattr(self, "layer_norm_weight"):
-            setattr(self, "weight", self.layer_norm_weight)
-        if hasattr(self, "layer_norm_bias"):
-            setattr(self, "bias", self.layer_norm_bias)
-
         # Set the activation type for AMP.
         TransformerEngineBaseModule.set_activation_dtype(self, inp)
...
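With both the `load_state_dict` override and the attribute aliasing in `forward` removed, state dicts that still carry the old `layer_norm_weight`/`layer_norm_bias` names must be renamed by the caller. A hedged migration sketch (`old_sd` is a hypothetical state dict saved by an older release; the hidden size is illustrative):

    import transformer_engine.pytorch as te

    layer = te.LayerNorm(1024)
    rename = {"layer_norm_weight": "weight", "layer_norm_bias": "bias"}
    layer.load_state_dict({rename.get(k, k): v for k, v in old_sd.items()})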
@@ -551,11 +551,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
     r"""
     Applies layer normalization followed by linear transformation to the incoming data.
 
-    .. warning::
-
-        Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in the next release (v1.0.0).
-
     Parameters
     ----------
     in_features : int
@@ -649,7 +644,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                  params_dtype: Optional[torch.dtype] = None,
                  parallel_mode: Optional[str] = None,
                  return_layernorm_output: bool = False,
-                 skip_weight_param_allocation: bool = False,
                  parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
                  zero_centered_gamma: bool = False,
                  ub_bulk_wgrad: bool = False,
@@ -660,14 +654,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                  ) -> None:
         super().__init__()
 
-        if skip_weight_param_allocation:
-            warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in the next release (v1.0.0). It is ignored"
-                "starting from v0.11.",
-                category=DeprecationWarning,
-            )
-
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
@@ -866,18 +852,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
     def forward(
         self,
         inp: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
 
-        .. warning::
-
-            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-            be fully removed in the next release (v1.0.0).
-
         Parameters
         ----------
         inp : torch.Tensor
@@ -897,12 +876,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                                        produced)
         """
 
-        if weight is not None or bias is not None:
-            raise RuntimeError(
-                "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in the next release (v1.0.0)."
-            )
-
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
                 self.bias if self.parameters_split is None
...
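`LayerNormLinear` now always allocates and owns its parameters, and `forward` accepts only the input tensor plus the optional `is_first_microbatch` hint. A minimal usage sketch with illustrative sizes:

    import torch
    import transformer_engine.pytorch as te

    layer = te.LayerNormLinear(1024, 3072)     # no skip_weight_param_allocation
    x = torch.randn(128, 1024, device="cuda")
    y = layer(x)                               # no weight= / bias= arguments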
@@ -479,11 +479,6 @@ class Linear(TransformerEngineBaseModule):
     On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
 
-    .. warning::
-
-        Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in the next release (v1.0.0).
-
     Parameters
     ----------
     in_features : int
@@ -558,7 +553,6 @@ class Linear(TransformerEngineBaseModule):
                  return_bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  parallel_mode: Optional[str] = None,
-                 skip_weight_param_allocation: bool = False,
                  parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
                  ub_split_rs: bool = False,
                  ub_split_ag: bool = False,
@@ -568,14 +562,6 @@ class Linear(TransformerEngineBaseModule):
                  ) -> None:
         super().__init__()
 
-        if skip_weight_param_allocation:
-            warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in the next release (v1.0.0). It has ignored"
-                "starting from v0.11.",
-                category=DeprecationWarning,
-            )
-
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
@@ -736,18 +722,11 @@ class Linear(TransformerEngineBaseModule):
     def forward(
         self,
         inp: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply the linear transformation to the input.
 
-        .. warning::
-
-            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-            be fully removed in the next release (v1.0.0).
-
         Parameters
         ----------
         inp : torch.Tensor
@@ -767,12 +746,6 @@ class Linear(TransformerEngineBaseModule):
                                        produced)
         """
 
-        if weight is not None or bias is not None:
-            raise RuntimeError(
-                "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in the next release (v1.0.0)."
-            )
-
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
                 self.bias if self.parameters_split is None
...
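The same simplification applies to `Linear`, which remains a drop-in replacement for `torch.nn.Linear` on NVIDIA GPUs; the only extra forward argument left is `is_first_microbatch`. A hedged sketch with illustrative sizes:

    import torch
    import transformer_engine.pytorch as te

    linear = te.Linear(768, 3072, bias=True)
    x = torch.randn(32, 768, device="cuda")
    out = linear(x, is_first_microbatch=True)  # weight/bias kwargs are gone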
@@ -68,11 +68,6 @@ class TransformerLayer(torch.nn.Module):
     TransformerLayer is made up of an attention block and a feedforward network (MLP).
     This standard layer is based on the paper "Attention Is All You Need".
 
-    .. warning::
-
-        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
-        are deprecated and will be fully removed in the next release (v1.0.0).
-
     .. note::
 
         Argument :attr:`attention_mask` will be ignored in the `forward` call when
@@ -224,8 +219,6 @@ class TransformerLayer(torch.nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  get_rng_state_tracker: Optional[Callable] = None,
                  fuse_wgrad_accumulation: bool = False,
-                 apply_query_key_layer_scaling: bool = False,  # pylint: disable=unused-argument
-                 attention_softmax_in_fp32: bool = True,  # pylint: disable=unused-argument
                  seq_length: Optional[int] = None,
                  micro_batch_size: Optional[int] = None,
                  sequence_parallel: bool = False,
@@ -245,12 +238,6 @@ class TransformerLayer(torch.nn.Module):
                  ) -> None:
         super().__init__()
 
-        warnings.warn(
-            "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
-            "are deprecated and will be fully removed in the next release (v1.0.0).",
-            category=DeprecationWarning,
-        )
-
         if ub_tp_comm_overlap:
             assert (
                 tex.userbuf_comm_available()
...
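Call sites that still pass the two removed arguments will now fail with a TypeError instead of a DeprecationWarning. A before/after sketch with illustrative hyperparameters:

    import transformer_engine.pytorch as te

    # Before: te.TransformerLayer(..., apply_query_key_layer_scaling=False,
    #                             attention_softmax_in_fp32=True)  # both ignored
    # After: simply drop both arguments.
    block = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                                num_attention_heads=16)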