Unverified Commit 95dea389 authored by Kirthi Shankar Sivamani, committed by GitHub

Add release to deprecation warnings (#447)



Change deprecation warnings to state the release (v1.0.0) in which the deprecated APIs will be removed.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06eebf66
@@ -10,29 +10,39 @@ from ..common.utils import deprecate_wrapper

 extend_logical_axis_rules = deprecate_wrapper(
     flax.extend_logical_axis_rules,
-    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module")
+    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
-    "DenseGeneral is moving to transformer_engine.jax.flax module")
+    "DenseGeneral is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 LayerNorm = deprecate_wrapper(flax.LayerNorm,
-    "LayerNorm is moving to transformer_engine.jax.flax module")
+    "LayerNorm is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 LayerNormDenseGeneral = deprecate_wrapper(
     flax.LayerNormDenseGeneral,
-    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module")
+    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
-    "LayerNormMLP is moving to transformer_engine.jax.flax module")
+    "LayerNormMLP is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 TransformerEngineBase = deprecate_wrapper(
     flax.TransformerEngineBase,
-    "TransformerEngineBase is moving to transformer_engine.jax.flax module")
+    "TransformerEngineBase is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 MultiHeadAttention = deprecate_wrapper(
-    flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module")
+    flax.MultiHeadAttention, "MultiHeadAttention is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 RelativePositionBiases = deprecate_wrapper(
     flax.RelativePositionBiases,
-    "RelativePositionBiases is moving to transformer_engine.jax.flax module")
+    "RelativePositionBiases is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 TransformerLayer = deprecate_wrapper(
-    flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module")
+    flax.TransformerLayer, "TransformerLayer is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")
 TransformerLayerType = deprecate_wrapper(
     flax.TransformerLayerType,
-    "TransformerLayerType is moving to transformer_engine.jax.flax module")
+    "TransformerLayerType is moving to transformer_engine.jax.flax module"
+    " and will be fully removed in the next release (v1.0.0).")

 __all__ = [
     'fp8_autocast', 'update_collections', 'update_fp8_metas', 'get_delayed_scaling',
...
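Each of these wrappers comes from deprecate_wrapper in transformer_engine.common.utils, whose body is outside this diff. A minimal sketch of how such a helper could be written (a hypothetical illustration, not the actual Transformer Engine implementation):

    import functools
    import warnings


    def deprecate_wrapper(obj, msg):
        """Return a stand-in for `obj` that emits `msg` as a DeprecationWarning."""
        if isinstance(obj, type):
            # For classes, subclass so isinstance() against the old name keeps working.
            class DeprecatedCls(obj):
                def __init__(self, *args, **kwargs):
                    warnings.warn(msg, category=DeprecationWarning)
                    super().__init__(*args, **kwargs)
            return DeprecatedCls

        @functools.wraps(obj)
        def deprecated_fn(*args, **kwargs):
            warnings.warn(msg, category=DeprecationWarning)
            return obj(*args, **kwargs)

        return deprecated_fn

Note that an Enum such as TransformerLayerType cannot be subclassed this way once it has members, so the real helper presumably handles that case differently.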
@@ -366,7 +366,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if isinstance(state, list):
             warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in a future release of Transformer Engine"
+                "This checkpoint format is deprecated and will be "
+                "removed in the next release (v1.0.0)."
             )

             # Retrieve checkpointed items.
@@ -412,7 +412,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         else:
             warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in a future release of Transformer Engine"
+                "This checkpoint format is deprecated and will be "
+                "removed in the next release (v1.0.0)."
             )

         # Load extra items.
         self.fp8_meta.update(state["extra_fp8_variables"])
...
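A detail these messages rely on: Python concatenates adjacent string literals with no separator, so every literal except the last needs a trailing space, otherwise the rendered warning runs words together ("will beremoved"). A quick self-contained demonstration:

    # Adjacent string literals concatenate with no separator.
    broken = ("This checkpoint format is deprecated and will be"
              "removed in the next release (v1.0.0).")
    print(broken)   # "...will beremoved..." -- words run together

    # A trailing space on each literal except the last fixes it.
    fixed = ("This checkpoint format is deprecated and will be "
             "removed in the next release (v1.0.0).")
    print(fixed)    # "...will be removed..."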
@@ -516,7 +516,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
     .. warning::

         Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in future releases.
+        be fully removed in the next release (v1.0.0).

     Parameters
     ----------
@@ -624,7 +624,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if skip_weight_param_allocation:
             warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in future releases. It is ignored"
-                "starting from v0.11.",
+                "Argument `skip_weight_param_allocation` is deprecated and "
+                "will be fully removed in the next release (v1.0.0). It is ignored "
+                "starting from v0.11.",
                 category=DeprecationWarning,
             )
@@ -831,7 +831,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
     .. warning::

         Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-        be fully removed in future releases.
+        be fully removed in the next release (v1.0.0).

     Parameters
     ----------
@@ -855,7 +855,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if weight is not None or bias is not None:
             raise RuntimeError(
                 "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in future releases."
+                "will be fully removed in the next release (v1.0.0)."
             )

         with self.prepare_forward(inp, is_first_microbatch) as inp:
...
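These warnings are emitted with category=DeprecationWarning, which Python's default filters suppress outside code running directly under __main__; callers who want to assert on them must opt in. A sketch using only the standard library:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        # Default filters hide DeprecationWarning outside __main__; opt in.
        warnings.simplefilter("always", DeprecationWarning)
        warnings.warn(
            "Argument `skip_weight_param_allocation` is deprecated and "
            "will be fully removed in the next release (v1.0.0).",
            category=DeprecationWarning,
        )

    assert any("skip_weight_param_allocation" in str(w.message) for w in caught)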
@@ -451,7 +451,7 @@ class Linear(TransformerEngineBaseModule):
     .. warning::

         Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in future releases.
+        be fully removed in the next release (v1.0.0).

     Parameters
     ----------
@@ -538,7 +538,7 @@ class Linear(TransformerEngineBaseModule):
         if skip_weight_param_allocation:
             warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in future releases. It has ignored"
-                "starting from v0.11.",
+                "Argument `skip_weight_param_allocation` is deprecated and "
+                "will be fully removed in the next release (v1.0.0). It is ignored "
+                "starting from v0.11.",
                 category=DeprecationWarning,
             )
@@ -706,7 +706,7 @@ class Linear(TransformerEngineBaseModule):
     .. warning::

         Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-        be fully removed in future releases.
+        be fully removed in the next release (v1.0.0).

     Parameters
     ----------
@@ -730,7 +730,7 @@ class Linear(TransformerEngineBaseModule):
         if weight is not None or bias is not None:
             raise RuntimeError(
                 "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in future releases."
+                "will be fully removed in the next release (v1.0.0)."
             )

         with self.prepare_forward(inp, is_first_microbatch) as inp:
...
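On the caller side, the deprecated weight/bias forward arguments of Linear and LayerNormLinear raise a RuntimeError with the updated message. A hypothetical usage sketch (assumes a CUDA-enabled Transformer Engine build; the layer(inp, weight=...) signature is taken from the diff above):

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(in_features=1024, out_features=1024)
    inp = torch.randn(16, 1024, device="cuda")

    out = layer(inp)  # supported call path

    try:
        # Passing the deprecated `weight`/`bias` arguments raises, per the diff above.
        out = layer(inp, weight=layer.weight)
    except RuntimeError as err:
        print(err)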
@@ -71,7 +71,7 @@ class TransformerLayer(torch.nn.Module):
     .. warning::

         Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
-        are deprecated and will be fully removed in future releases.
+        are deprecated and will be fully removed in the next release (v1.0.0).

     .. note::
@@ -247,7 +247,7 @@ class TransformerLayer(torch.nn.Module):
             warnings.warn(
-                "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
-                "are deprecated and will be fully removed in future releases.",
+                "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling` "
+                "are deprecated and will be fully removed in the next release (v1.0.0).",
                 category=DeprecationWarning,
             )
...
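Downstream code that wants to find every deprecated call site before the v1.0.0 removal can promote these warnings to errors, for example in a CI job:

    import warnings

    # Turn every DeprecationWarning into an exception so deprecated
    # Transformer Engine call sites fail tests instead of just logging.
    warnings.simplefilter("error", DeprecationWarning)

The same effect is available from the command line with python -W error::DeprecationWarning.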