chenpangpang / transformers · Commit 47ca0eaa (unverified)
Authored Jan 04, 2021 by Stas Bekman; committed by GitHub on Jan 04, 2021
replace apex.normalization.FusedLayerNorm with torch.nn.LayerNorm (#9386)
Parent: 75ff5305

Showing 3 changed files with 19 additions and 48 deletions (+19 −48):

src/transformers/models/bart/modeling_bart.py (+10 −20)
src/transformers/models/fsmt/modeling_fsmt.py (+1 −11)
src/transformers/models/prophetnet/modeling_prophetnet.py (+8 −17)
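For context, the pattern this commit removes, condensed into a minimal, self-contained sketch (the names legacy_layer_norm and hidden_size are illustrative, not taken from the diff): each of these models previously built its layer norm through a small factory that preferred apex's FusedLayerNorm and fell back to torch.nn.LayerNorm, whereas the new code constructs the native module directly. Since torch.nn.LayerNorm defaults to eps=1e-5 and elementwise_affine=True, the same defaults the removed helpers used, behavior is unchanged when apex is not installed.

    import torch
    from torch.nn import LayerNorm  # what the diff switches to

    # Old helper, condensed from the deleted BartLayerNorm / ProphetNetLayerNorm
    # functions: prefer apex's fused kernel when installed, otherwise fall back.
    def legacy_layer_norm(normalized_shape, eps=1e-5, elementwise_affine=True):
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    hidden_size = 16  # illustrative size only
    old_norm = legacy_layer_norm(hidden_size)
    new_norm = LayerNorm(hidden_size)  # the post-commit code path

    x = torch.randn(2, 4, hidden_size)
    # True when apex is absent: both objects are plain torch.nn.LayerNorm with
    # identical default initialization (weight=1, bias=0).
    print(torch.allclose(old_norm(x), new_norm(x)))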
src/transformers/models/bart/modeling_bart.py
@@ -22,7 +22,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (

@@ -109,16 +109,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)


-def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class BartLearnedPositionalEmbedding(nn.Embedding):
     """
     This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting

@@ -321,13 +311,13 @@ class BartEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
         )
         self.normalize_before = config.normalize_before
-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
         self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
         self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: bool = False):
         """

@@ -380,17 +370,17 @@ class BartDecoderLayer(nn.Module):
         self.activation_dropout = config.activation_dropout
         self.normalize_before = config.normalize_before

-        self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.encoder_attn = BartAttention(
             self.embed_dim,
             config.decoder_attention_heads,
             dropout=config.attention_dropout,
             is_decoder=True,
         )
-        self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
+        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
-        self.final_layer_norm = BartLayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(
         self,

@@ -672,9 +662,9 @@ class BartEncoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
+        self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

         self.init_weights()

@@ -812,8 +802,8 @@ class BartDecoder(BartPretrainedModel):
             config.extra_pos_embeddings,
         )
         self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
-        self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
+        self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

         self.init_weights()
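As a quick sanity check on what the drop-in module computes, here is a minimal sketch (not part of the commit; shapes are illustrative) showing that a freshly initialized torch.nn.LayerNorm matches hand-written layer normalization over the last dimension:

    import torch
    from torch import nn

    torch.manual_seed(0)
    x = torch.randn(2, 5, 16)        # (batch, seq, hidden), illustrative shapes

    ln = nn.LayerNorm(16)            # defaults: eps=1e-5, elementwise_affine=True
    out = ln(x)

    # Manual layer norm: normalize over the last dimension with biased variance.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + ln.eps)

    # At initialization the affine weight is all ones and the bias all zeros,
    # so the module output equals the manual computation.
    print(torch.allclose(out, manual, atol=1e-6))  # True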
src/transformers/models/fsmt/modeling_fsmt.py
@@ -34,7 +34,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (

@@ -264,16 +264,6 @@ FSMT_INPUTS_DOCSTRING = r"""
 """


-have_fused_layer_norm = False
-try:
-    from apex.normalization import FusedLayerNorm
-
-    have_fused_layer_norm = True
-except ImportError:
-    pass
-LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
-
-
 def invert_mask(attention_mask):
     """Turns 1->0, 0->1, False->True, True-> False"""
     assert attention_mask.dim() == 2
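The fsmt change has a slightly different shape: it deletes a module-level conditional alias rather than a factory function. A tiny sketch (assuming apex is not installed) of why the plain import is equivalent:

    import torch
    from torch.nn import LayerNorm   # the import this diff adds

    # The deleted fsmt block bound a module-level name `LayerNorm` to either
    # apex.normalization.FusedLayerNorm or torch.nn.LayerNorm. Without apex,
    # that alias resolved to exactly the class the plain import now brings in:
    assert LayerNorm is torch.nn.LayerNorm
    print(LayerNorm(8))   # LayerNorm((8,), eps=1e-05, elementwise_affine=True)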
src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -23,6 +23,7 @@ from typing import Dict, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+from torch.nn import LayerNorm

 from ...activations import ACT2FN
 from ...file_utils import (

@@ -510,16 +511,6 @@ class ProphetNetDecoderLMOutput(ModelOutput):
     cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


-def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
-    try:
-        from apex.normalization import FusedLayerNorm
-
-        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
-    except ImportError:
-        pass
-    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
-
-
 class ProphetNetPreTrainedModel(PreTrainedModel):
     config_class = ProphetNetConfig
     base_model_prefix = "prophetnet"

@@ -1044,11 +1035,11 @@ class ProphetNetEncoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetSelfAttention(config, config.num_encoder_attention_heads)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.encoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(self, hidden_states, attention_mask):
         # 1st residual block

@@ -1073,16 +1064,16 @@ class ProphetNetDecoderLayer(nn.Module):
         super().__init__()
         # 1st residual block
         self.self_attn = ProphetNetNgramProphetNetSelfAttention(config)
-        self.self_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

         # 2nd residual block
         if config.add_cross_attention:
             self.cross_attn = ProphetNetSelfAttention(config, config.num_decoder_attention_heads)
-            self.cross_attn_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

         # 3rd residual block
         self.feed_forward = ProhpetNetFeedForward(config, config.decoder_ffn_dim)
-        self.feed_forward_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

     def forward(
         self,

@@ -1154,7 +1145,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
             else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
         )
         self.position_embeddings = ProhpetNetPositionalEmbeddings(config)
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

@@ -1274,7 +1265,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
-        self.embeddings_layer_norm = ProphetNetLayerNorm(config.hidden_size)
+        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

         self.init_weights()
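For anyone who does have apex installed, a small comparison sketch (again not part of the commit; sizes are illustrative, and the FusedLayerNorm arguments mirror the ones the removed helpers passed) can confirm that the native and fused modules agree on the same input:

    import torch
    from torch.nn import LayerNorm

    hidden_size = 32                                   # illustrative size
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4, 8, hidden_size, device=device)

    native = LayerNorm(hidden_size, eps=1e-5, elementwise_affine=True).to(device)

    try:
        from apex.normalization import FusedLayerNorm  # the import the removed code guarded
    except ImportError:
        FusedLayerNorm = None

    if FusedLayerNorm is None:
        print("apex not installed; torch.nn.LayerNorm output shape:", native(x).shape)
    else:
        # Same positional arguments the removed helpers passed; note the fused
        # kernel primarily targets CUDA tensors.
        fused = FusedLayerNorm(hidden_size, 1e-5, True).to(device)
        # Both modules start with weight=1 and bias=0, so outputs should agree closely.
        print(torch.allclose(native(x), fused(x), atol=1e-5))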