OpenDAS / TransformerEngine · Commits

Commit daa5e184 (unverified)
Remove deprecated APIs (#464)

Authored Oct 10, 2023 by Kirthi Shankar Sivamani; committed via GitHub on Oct 10, 2023.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Parent: 29b4670c

Showing 7 changed files with 15 additions and 186 deletions (+15 / -186):
  transformer_engine/jax/__init__.py                      +1  -39
  transformer_engine/jax/flax/__init__.py                 +6   -0
  transformer_engine/pytorch/module/base.py               +7  -55
  transformer_engine/pytorch/module/layernorm.py          +1  -25
  transformer_engine/pytorch/module/layernorm_linear.py   +0  -27
  transformer_engine/pytorch/module/linear.py             +0  -27
  transformer_engine/pytorch/transformer.py               +0  -13
transformer_engine/jax/__init__.py
@@ -6,47 +6,9 @@
 from . import flax
 from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
 from .sharding import MajorShardingType, ShardingResource, ShardingType
-from ..common.utils import deprecate_wrapper
-
-extend_logical_axis_rules = deprecate_wrapper(flax.extend_logical_axis_rules,
-    "extend_logical_axis_rules is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-DenseGeneral = deprecate_wrapper(flax.DenseGeneral,
-    "DenseGeneral is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-LayerNorm = deprecate_wrapper(flax.LayerNorm,
-    "LayerNorm is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-LayerNormDenseGeneral = deprecate_wrapper(flax.LayerNormDenseGeneral,
-    "LayerNormDenseGeneral is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-LayerNormMLP = deprecate_wrapper(flax.LayerNormMLP,
-    "LayerNormMLP is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-TransformerEngineBase = deprecate_wrapper(flax.TransformerEngineBase,
-    "TransformerEngineBase is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-MultiHeadAttention = deprecate_wrapper(flax.MultiHeadAttention,
-    "MultiHeadAttention is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-RelativePositionBiases = deprecate_wrapper(flax.RelativePositionBiases,
-    "RelativePositionBiases is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-TransformerLayer = deprecate_wrapper(flax.TransformerLayer,
-    "TransformerLayer is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")
-TransformerLayerType = deprecate_wrapper(flax.TransformerLayerType,
-    "TransformerLayerType is moving to transformer_engine.jax.flax module"
-    " and will be fully removed in the next release (v1.0.0).")

 __all__ = [
     'fp8_autocast',
     'update_collections',
     'update_fp8_metas',
     'get_delayed_scaling',
     'MajorShardingType',
     'ShardingResource',
     'ShardingType',
     'flax',
     'praxis',
-    'DenseGeneral',
-    'LayerNorm',
-    'LayerNormDenseGeneral',
-    'LayerNormMLP',
-    'TransformerEngineBase',
-    'MultiHeadAttention',
-    'RelativePositionBiases',
-    'TransformerLayer',
-    'TransformerLayerType',
 ]
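All of these aliases went through deprecate_wrapper, so the flax submodule is now the only import path. A minimal migration sketch, assuming only the names visible in this diff (the constructor arguments are illustrative):

    # Before (deprecated, removed by this commit):
    # from transformer_engine.jax import DenseGeneral, TransformerLayer

    # After: import the flax/linen modules directly from the flax submodule.
    from transformer_engine.jax.flax import DenseGeneral, TransformerLayer

    dense = DenseGeneral(features=1024)
    layer = TransformerLayer(hidden_size=512, mlp_hidden_size=2048,
                             num_attention_heads=8)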
transformer_engine/jax/flax/__init__.py
@@ -7,3 +7,9 @@ from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
 from .transformer import extend_logical_axis_rules
 from .transformer import MultiHeadAttention, RelativePositionBiases
 from .transformer import TransformerLayer, TransformerLayerType
+
+__all__ = [
+    'DenseGeneral', 'LayerNorm', 'LayerNormDenseGeneral', 'LayerNormMLP',
+    'TransformerEngineBase', 'extend_logical_axis_rules', 'MultiHeadAttention',
+    'RelativePositionBiases', 'TransformerLayer', 'TransformerLayerType',
+]
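The new __all__ pins the public surface of the submodule, so a star import exposes exactly these ten names. An illustrative check:

    import transformer_engine.jax.flax as te_flax

    # Names listed in __all__ are what `from transformer_engine.jax.flax import *` yields.
    assert 'TransformerLayer' in te_flax.__all__
    assert 'extend_logical_axis_rules' in te_flax.__all__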
transformer_engine/pytorch/module/base.py
@@ -334,9 +334,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         """Save before checkpointing."""
         state = None

-        # Maintain backward compatibility.
-        fp8_checkpoint = "fp8_checkpoint" in self.fp8_meta and self.fp8_meta["fp8_checkpoint"]
-        fp8_checkpoint = fp8_checkpoint or self.fp8 or self.fp8_calibration
+        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration

         if fp8_checkpoint:
             state = {}
@@ -369,44 +367,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if state is None:
             return

-        # Maintain backward compatibility with v0.2.0 and older.
-        if isinstance(state, list):
-            warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in the next release (v1.0.0)."
-            )
-
-            # Retrieve checkpointed items.
-            scale_fwd = state[0]
-            amax_history_fwd = state[1]
-            scale_bwd = state[2]
-            amax_history_bwd = state[3]
-            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
-            self.fp8_meta["num_gemms"] = (
-                amax_history_fwd.shape[1] // 2  # Two FWD tensors per GEMM
-            )
-
-            # Initialize before loading
-            self.init_fp8_meta_tensors()
-            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
-            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
-            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
-            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
-
-            # Restore global FP8 buffer state.
-            FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state[4])
-            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
-            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
-            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
-            self.fp8_meta["autocast_id_fwd"] = state[8]
-            self.fp8_meta["autocast_id_bwd"] = state[9]
-            return
-
         if isinstance(state, torch.Tensor):
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
             state.seek(0)
             state = torch.load(state, map_location='cuda')
+        else:
+            raise RuntimeError("Unsupported checkpoint format.")

         if state is None:
             return
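The list-based v0.2.0 format is gone; only the two formats above survive. A minimal sketch of how they round-trip, assuming an illustrative payload dict rather than the module's real fp8_meta contents:

    import io
    import pickle
    import torch

    payload = {"scale_fwd": torch.ones(4)}  # illustrative stand-in

    # Format 1: state pickled into a flat uint8 tensor.
    as_tensor = torch.frombuffer(bytearray(pickle.dumps(payload)), dtype=torch.uint8)
    restored = pickle.loads(as_tensor.detach().cpu().numpy().tobytes())

    # Format 2: state serialized via torch.save into a BytesIO buffer.
    buf = io.BytesIO()
    torch.save(payload, buf)
    buf.seek(0)
    restored = torch.load(buf)  # set_extra_state passes map_location='cuda'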
@@ -414,13 +381,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Restore global FP8 amax buffer.
         FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"])

         # Restore global FP8 state.
-        if "global_fp8_state" in state:
-            FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])
-        else:
-            warnings.warn(
-                "This checkpoint format is deprecated and will be"
-                "removed in the next release (v1.0.0)."
-            )
+        FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"])

         # Load extra items.
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
@@ -433,18 +395,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
         self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
         self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-
-        # Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
-        if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
-            assert ("scale_inv_fwd" not in state and "scale_inv_bwd" not in state), \
-                "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
-            self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0 / state["scale_fwd"])
-            self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0 / state["scale_bwd"])
-        else:
-            self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-            self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
+        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])

     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
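Checkpoints must now carry scale_inv_fwd and scale_inv_bwd directly; the deleted fallback merely reconstructed them as reciprocals of the saved scales. In sketch form (tensor values illustrative):

    import torch

    scale_fwd = torch.tensor([2.0, 4.0, 8.0])

    # What the removed backward-compatibility branch computed:
    scale_inv_fwd = 1.0 / scale_fwd

    assert torch.allclose(scale_fwd * scale_inv_fwd, torch.ones_like(scale_fwd))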
transformer_engine/pytorch/module/layernorm.py
@@ -4,7 +4,7 @@
 """LayerNorm API"""
 import os
-from typing import Union, Tuple, Any, Mapping, Optional
+from typing import Union, Tuple, Optional

 import torch
 from torch.nn.parameter import Parameter
@@ -148,23 +148,6 @@ class LayerNorm(torch.nn.Module):
         self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

-    def load_state_dict(
-        self,
-        state_dict: Mapping[str, Any],
-        strict: bool = True,
-    ) -> None:
-        """Override PyTorch loader to maintain backward compatibility
-        with previous version of LayerNorm parameter names.
-        """
-        if "layer_norm_weight" in state_dict:
-            state_dict["weight"] = state_dict["layer_norm_weight"]
-            del state_dict["layer_norm_weight"]
-        if "layer_norm_bias" in state_dict:
-            state_dict["bias"] = state_dict["layer_norm_bias"]
-            del state_dict["layer_norm_bias"]
-
-        super().load_state_dict(state_dict, strict)
-
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         if not self.zero_centered_gamma:
@@ -173,16 +156,9 @@ class LayerNorm(torch.nn.Module):
             init.zeros_(self.weight)
         init.zeros_(self.bias)

     @no_torch_dynamo
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
-
-        # Maintain backward compatibility.
-        if hasattr(self, "layer_norm_weight"):
-            setattr(self, "weight", self.layer_norm_weight)
-        if hasattr(self, "layer_norm_bias"):
-            setattr(self, "bias", self.layer_norm_bias)
-
         # Set the activation type for AMP.
         TransformerEngineBaseModule.set_activation_dtype(self, inp)
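With the loader override removed, checkpoints saved under the old parameter names need a one-time key rename before the stock torch.nn.Module loader. A hedged migration sketch (the helper name is hypothetical, not part of the library):

    # Hypothetical helper: rename pre-v1.0.0 LayerNorm keys, then load normally.
    def migrate_layernorm_keys(state_dict):
        renames = {"layer_norm_weight": "weight", "layer_norm_bias": "bias"}
        return {renames.get(key, key): value for key, value in state_dict.items()}

    # Usage, assuming `ln` is a transformer_engine.pytorch.LayerNorm:
    # ln.load_state_dict(migrate_layernorm_keys(old_state_dict))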
transformer_engine/pytorch/module/layernorm_linear.py
@@ -551,11 +551,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
     r"""
     Applies layer normalization followed by linear transformation to the incoming data.

-    .. warning::
-
-        Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in the next release (v1.0.0).
-
     Parameters
     ----------
     in_features : int

@@ -649,7 +644,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         params_dtype: Optional[torch.dtype] = None,
         parallel_mode: Optional[str] = None,
         return_layernorm_output: bool = False,
-        skip_weight_param_allocation: bool = False,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         zero_centered_gamma: bool = False,
         ub_bulk_wgrad: bool = False,

@@ -660,14 +654,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()

-        if skip_weight_param_allocation:
-            warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in the next release (v1.0.0). It is ignored"
-                "starting from v0.11.",
-                category=DeprecationWarning,
-            )
-
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features

@@ -866,18 +852,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
     def forward(
         self,
         inp: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.

-        .. warning::
-
-            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-            be fully removed in the next release (v1.0.0).
-
         Parameters
         ----------
         inp : torch.Tensor

@@ -897,12 +876,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
             produced)
         """

-        if weight is not None or bias is not None:
-            raise RuntimeError(
-                "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in the next release (v1.0.0)."
-            )
-
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
                 self.bias if self.parameters_split is None
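After this change the module's own parameters are the only weights, so forward takes just the input (plus is_first_microbatch). An illustrative call, with made-up sizes and assuming a CUDA device:

    import torch
    from transformer_engine.pytorch import LayerNormLinear

    layer = LayerNormLinear(in_features=1024, out_features=1024, bias=True).cuda()
    x = torch.randn(8, 1024, device="cuda")

    out = layer(x)  # passing weight=... or bias=... now raises TypeError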
transformer_engine/pytorch/module/linear.py
@@ -479,11 +479,6 @@ class Linear(TransformerEngineBaseModule):
     On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

-    .. warning::
-
-        Argument :attr:`skip_weight_param_allocation` is deprecated and will
-        be fully removed in the next release (v1.0.0).
-
     Parameters
     ----------
     in_features : int

@@ -558,7 +553,6 @@ class Linear(TransformerEngineBaseModule):
         return_bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         parallel_mode: Optional[str] = None,
-        skip_weight_param_allocation: bool = False,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         ub_split_rs: bool = False,
         ub_split_ag: bool = False,

@@ -568,14 +562,6 @@ class Linear(TransformerEngineBaseModule):
     ) -> None:
         super().__init__()

-        if skip_weight_param_allocation:
-            warnings.warn(
-                "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in the next release (v1.0.0). It has ignored"
-                "starting from v0.11.",
-                category=DeprecationWarning,
-            )
-
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features

@@ -736,18 +722,11 @@ class Linear(TransformerEngineBaseModule):
     def forward(
         self,
         inp: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
         is_first_microbatch: Optional[bool] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply the linear transformation to the input.

-        .. warning::
-
-            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
-            be fully removed in the next release (v1.0.0).
-
         Parameters
         ----------
         inp : torch.Tensor

@@ -767,12 +746,6 @@ class Linear(TransformerEngineBaseModule):
             produced)
         """

-        if weight is not None or bias is not None:
-            raise RuntimeError(
-                "Arguments `weight` and `bias` are deprecated and "
-                "will be fully removed in the next release (v1.0.0)."
-            )
-
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
                 self.bias if self.parameters_split is None
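Linear now matches torch.nn.Linear's call shape. An illustrative FP8 sketch, assuming an FP8-capable GPU; the recipe values here are arbitrary, not recommendations:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    linear = te.Linear(768, 768, bias=True)
    x = torch.randn(16, 768, device="cuda")

    fp8_recipe = recipe.DelayedScaling(margin=0, amax_history_len=16)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = linear(x)  # no weight=/bias= arguments anymore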
transformer_engine/pytorch/transformer.py
@@ -68,11 +68,6 @@ class TransformerLayer(torch.nn.Module):
     TransformerLayer is made up of an attention block and a feedforward network (MLP).
     This standard layer is based on the paper "Attention Is All You Need".

-    .. warning::
-
-        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
-        are deprecated and will be fully removed in the next release (v1.0.0).
-
     .. note::

         Argument :attr:`attention_mask` will be ignored in the `forward` call when

@@ -224,8 +219,6 @@ class TransformerLayer(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        apply_query_key_layer_scaling: bool = False,  # pylint: disable=unused-argument
-        attention_softmax_in_fp32: bool = True,  # pylint: disable=unused-argument
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         sequence_parallel: bool = False,

@@ -245,12 +238,6 @@ class TransformerLayer(torch.nn.Module):
     ) -> None:
         super().__init__()

-        warnings.warn(
-            "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
-            "are deprecated and will be fully removed in the next release (v1.0.0).",
-            category=DeprecationWarning,
-        )
-
         if ub_tp_comm_overlap:
             assert (
                 tex.userbuf_comm_available()
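Callers that still pass the two removed knobs now get a TypeError from __init__ instead of a DeprecationWarning. An illustrative construction, with made-up sizes and assuming the default sequence-first [seq, batch, hidden] input layout:

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        # apply_query_key_layer_scaling=... / attention_softmax_in_fp32=...
        # would now raise TypeError: unexpected keyword argument.
    )
    hidden = torch.randn(128, 2, 1024, device="cuda")
    out = layer(hidden)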