Unverified Commit daa5e184 authored by Kirthi Shankar Sivamani, committed by GitHub

Remove deprecated APIs (#464)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 29b4670c
......@@ -6,47 +6,9 @@
from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper
extend_logical_axis_rules = deprecate_wrapper(
flax.extend_logical_axis_rules,
"extend_logical_axis_rules is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
"DenseGeneral is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
LayerNorm = deprecate_wrapper(flax.LayerNorm,
"LayerNorm is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
LayerNormDenseGeneral = deprecate_wrapper(
flax.LayerNormDenseGeneral,
"LayerNormDenseGeneral is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
"LayerNormMLP is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
TransformerEngineBase = deprecate_wrapper(
flax.TransformerEngineBase,
"TransformerEngineBase is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
MultiHeadAttention = deprecate_wrapper(
flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
RelativePositionBiases = deprecate_wrapper(
flax.RelativePositionBiases,
"RelativePositionBiases is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
TransformerLayer = deprecate_wrapper(
flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
TransformerLayerType = deprecate_wrapper(
flax.TransformerLayerType,
"TransformerLayerType is moving to transformer_engine.jax.flax module"
" and will be fully removed in the next release (v1.0.0).")
__all__ = [
'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis', 'DenseGeneral',
'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP', 'TransformerEngineBase',
'MultiHeadAttention', 'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType'
'MajorShardingType', 'ShardingResource', 'ShardingType', 'flax', 'praxis',
]
......@@ -7,3 +7,9 @@ from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .transformer import extend_logical_axis_rules
from .transformer import MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
__all__ = [
'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP',
'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention',
'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType',
]
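
With the deprecated top-level aliases removed from transformer_engine.jax, the Flax layers listed in the __all__ above are importable only from transformer_engine.jax.flax. A minimal migration sketch (import paths are taken from this diff; no constructor arguments are shown, since they are not part of this change):

# Before this change the top-level import worked but emitted a deprecation warning:
#   from transformer_engine.jax import DenseGeneral, TransformerLayer
# After this change, import the Flax layers from the flax submodule directly:
from transformer_engine.jax.flax import DenseGeneral, TransformerLayer

# Top-level utilities such as fp8_autocast are unaffected and stay where they were:
from transformer_engine.jax import fp8_autocast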
......@@ -334,9 +334,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Save before checkpointing."""
state = None
# Maintain backward compatibility.
fp8_checkpoint = "fp8_checkpoint" in self.fp8_meta and self.fp8_meta["fp8_checkpoint"]
fp8_checkpoint = fp8_checkpoint or self.fp8 or self.fp8_calibration
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if fp8_checkpoint:
state = {}
......@@ -369,44 +367,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if state is None:
return
# Maintain backward compatibility with v0.2.0 and older.
if isinstance(state, list):
warnings.warn(
"This checkpoint format is deprecated and will be"
"removed in the next release (v1.0.0)."
)
# Retrieve checkpointed items.
scale_fwd = state[0]
amax_history_fwd = state[1]
scale_bwd = state[2]
amax_history_bwd = state[3]
self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
self.fp8_meta["num_gemms"] = (
amax_history_fwd.shape[1] // 2
) # Two FWD tensors per GEMM
# Initialize before loading
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
# Restore global FP8 buffer state.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
self.fp8_meta["autocast_id_fwd"] = state[8]
self.fp8_meta["autocast_id_bwd"] = state[9]
return
if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
state = torch.load(state, map_location='cuda')
else:
raise RuntimeError("Unsupported checkpoint format.")
if state is None:
return
......@@ -414,13 +381,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Restore global FP8 amax buffer.
FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])
# Restore global FP8 state.
if "global_fp8_state" in state:
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
else:
warnings.warn(
"This checkpoint format is deprecated and will be"
"removed in the next release (v1.0.0)."
)
FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
......@@ -433,18 +395,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
# Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
assert (
"scale_inv_fwd" not in state and "scale_inv_bwd" not in state
), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
else:
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
......
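
With the list-format and legacy-key fallbacks removed, the checkpoint load path in TransformerEngineBaseModule only accepts the dict-style FP8 state, and the scale_inv and global_fp8_state entries are now required rather than reconstructed. A hedged sketch of the keys the retained loader reads (inferred from the code above; the exact contents of extra_fp8_variables are not shown in this diff):

# Keys the retained loader reads from the deserialized state dict; checkpoints
# saved as a plain list (v0.2.0 and older) or without scale_inv_fwd/scale_inv_bwd
# and global_fp8_state will no longer load.
expected_state_keys = [
    "scale_fwd", "amax_history_fwd",
    "scale_bwd", "amax_history_bwd",
    "scale_inv_fwd", "scale_inv_bwd",        # no longer recomputed as 1/scale
    "global_fp8_buffer", "global_fp8_state",
    "extra_fp8_variables",
]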
......@@ -4,7 +4,7 @@
"""LayerNorm API"""
import os
from typing import Union, Tuple, Any, Mapping, Optional
from typing import Union, Tuple, Optional
import torch
from torch.nn.parameter import Parameter
......@@ -148,23 +148,6 @@ class LayerNorm(torch.nn.Module):
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def load_state_dict(
self,
state_dict: Mapping[str, Any],
strict: bool = True,
) -> None:
"""Override PyTorch loader to maintain backward compatibility
with previous version of LayerNorm parameter names.
"""
if "layer_norm_weight" in state_dict:
state_dict["weight"] = state_dict["layer_norm_weight"]
del state_dict["layer_norm_weight"]
if "layer_norm_bias" in state_dict:
state_dict["bias"] = state_dict["layer_norm_bias"]
del state_dict["layer_norm_bias"]
super().load_state_dict(state_dict, strict)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
......@@ -173,16 +156,9 @@ class LayerNorm(torch.nn.Module):
init.zeros_(self.weight)
init.zeros_(self.bias)
@no_torch_dynamo
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
if hasattr(self, "layer_norm_weight"):
setattr(self, "weight", self.layer_norm_weight)
if hasattr(self, "layer_norm_bias"):
setattr(self, "bias", self.layer_norm_bias)
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
......
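
Because the load_state_dict override that remapped the old parameter names is removed, LayerNorm state dicts saved with layer_norm_weight / layer_norm_bias keys no longer load transparently. A hypothetical migration helper (the function and its name are mine, not part of Transformer Engine), using the old-to-new mapping from the removed override:

def migrate_layernorm_state_dict(state_dict):
    """Rename pre-v1.0.0 LayerNorm parameter keys to the current names."""
    rename_map = {"layer_norm_weight": "weight", "layer_norm_bias": "bias"}
    return {rename_map.get(key, key): value for key, value in state_dict.items()}

# Usage sketch: ln.load_state_dict(migrate_layernorm_state_dict(old_state_dict))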
......@@ -551,11 +551,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
.. warning::
Argument :attr:`skip_weight_param_allocation` is deprecated and will
be fully removed in the next release (v1.0.0).
Parameters
----------
in_features : int
......@@ -649,7 +644,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
......@@ -660,14 +654,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
) -> None:
super().__init__()
if skip_weight_param_allocation:
warnings.warn(
"Argument `skip_weight_param_allocation` is deprecated and"
"will be fully removed in the next release (v1.0.0). It is ignored"
"starting from v0.11.",
category=DeprecationWarning,
)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
self.out_features = out_features
......@@ -866,18 +852,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
.. warning::
Arguments :attr:`weight` and :attr:`bias` are deprecated and will
be fully removed in the next release (v1.0.0).
Parameters
----------
inp : torch.Tensor
......@@ -897,12 +876,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
produced)
"""
if weight is not None or bias is not None:
raise RuntimeError(
"Arguments `weight` and `bias` are deprecated and "
"will be fully removed in the next release (v1.0.0)."
)
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
self.bias if self.parameters_split is None
......
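
After this change LayerNormLinear always allocates its own parameters (skip_weight_param_allocation is gone from the constructor) and forward only takes the input and is_first_microbatch. A minimal usage sketch under those signatures (the feature sizes and input shape are illustrative):

import torch
import transformer_engine.pytorch as te

# Constructor without skip_weight_param_allocation; forward without weight/bias.
layer = te.LayerNormLinear(1024, 1024).cuda()
inp = torch.randn(16, 1024, device="cuda")
out = layer(inp, is_first_microbatch=None)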
......@@ -479,11 +479,6 @@ class Linear(TransformerEngineBaseModule):
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
.. warning::
Argument :attr:`skip_weight_param_allocation` is deprecated and will
be fully removed in the next release (v1.0.0).
Parameters
----------
in_features : int
......@@ -558,7 +553,6 @@ class Linear(TransformerEngineBaseModule):
return_bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
......@@ -568,14 +562,6 @@ class Linear(TransformerEngineBaseModule):
) -> None:
super().__init__()
if skip_weight_param_allocation:
warnings.warn(
"Argument `skip_weight_param_allocation` is deprecated and"
"will be fully removed in the next release (v1.0.0). It has ignored"
"starting from v0.11.",
category=DeprecationWarning,
)
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.in_features = in_features
self.out_features = out_features
......@@ -736,18 +722,11 @@ class Linear(TransformerEngineBaseModule):
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
.. warning::
Arguments :attr:`weight` and :attr:`bias` are deprecated and will
be fully removed in the next release (v1.0.0).
Parameters
----------
inp : torch.Tensor
......@@ -767,12 +746,6 @@ class Linear(TransformerEngineBaseModule):
produced)
"""
if weight is not None or bias is not None:
raise RuntimeError(
"Arguments `weight` and `bias` are deprecated and "
"will be fully removed in the next release (v1.0.0)."
)
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
self.bias if self.parameters_split is None
......
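
The same signature cleanup applies to Linear: skip_weight_param_allocation is removed from the constructor, and forward no longer has weight/bias parameters at all, so callers that still pass them fail at call time rather than hitting the deprecation RuntimeError shown above. A short sketch of the difference (module sizes are illustrative):

import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024).cuda()
inp = torch.randn(16, 1024, device="cuda")
out = linear(inp)            # supported path: module-owned weight and bias are used
# linear(inp, weight=w)      # now a TypeError (unexpected keyword argument), since
#                            # the parameter was removed from the signature entirely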
......@@ -68,11 +68,6 @@ class TransformerLayer(torch.nn.Module):
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
.. warning::
Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
are deprecated and will be fully removed in the next release (v1.0.0).
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
......@@ -224,8 +219,6 @@ class TransformerLayer(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
get_rng_state_tracker: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
apply_query_key_layer_scaling: bool = False, # pylint: disable=unused-argument
attention_softmax_in_fp32: bool = True, # pylint: disable=unused-argument
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
sequence_parallel: bool = False,
......@@ -245,12 +238,6 @@ class TransformerLayer(torch.nn.Module):
) -> None:
super().__init__()
warnings.warn(
"Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
"are deprecated and will be fully removed in the next release (v1.0.0).",
category=DeprecationWarning,
)
if ub_tp_comm_overlap:
assert (
tex.userbuf_comm_available()
......
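
TransformerLayer now rejects apply_query_key_layer_scaling and attention_softmax_in_fp32 outright; previously they were accepted but ignored, with a blanket DeprecationWarning emitted from the constructor. A hedged construction sketch under the new signature (the size values are illustrative; other keyword arguments are unchanged by this diff):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    # apply_query_key_layer_scaling / attention_softmax_in_fp32 were removed;
    # passing either now raises a TypeError instead of warning and ignoring it.
)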