chenpangpang / transformers
Commit 4b956b2a authored Sep 13, 2019 by thomwolf

    add layer_norm_epsilon configuration for transformer xl

Parent: b97af8cc

Showing 5 changed files with 65 additions and 53 deletions
pytorch_transformers/configuration_transfo_xl.py            +34  -33
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py    +3   -3
pytorch_transformers/modeling_tf_pytorch_utils.py            +3   -1
pytorch_transformers/modeling_tf_transfo_xl.py              +14   -8
pytorch_transformers/modeling_transfo_xl.py                 +11   -8
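The net effect of the commit is that the layer-norm epsilon used throughout Transformer-XL becomes a configuration field instead of a hard-coded constant. A minimal usage sketch, assuming the pytorch_transformers API at this commit (the vocabulary size is the WikiText-103 default and only illustrative; the attribute path follows the modules shown in the diffs below):

    # Sketch only, not part of the commit.
    from pytorch_transformers import TransfoXLConfig, TransfoXLModel

    # layer_norm_epsilon defaults to 1e-5; any other value is forwarded to every
    # LayerNorm / LayerNormalization built by the modeling files below.
    config = TransfoXLConfig(vocab_size_or_config_json_file=267735,
                             layer_norm_epsilon=1e-6)
    model = TransfoXLModel(config)

    # Each decoder layer's norms now use the configured epsilon.
    print(model.layers[0].pos_ff.layer_norm.eps)   # 1e-06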
pytorch_transformers/configuration_transfo_xl.py

@@ -95,19 +95,12 @@ class TransfoXLConfig(PretrainedConfig):
                  init_range=0.01,
                  proj_init_std=0.01,
                  init_std=0.02,
+                 layer_norm_epsilon=1e-5,
                  **kwargs):
         """Constructs TransfoXLConfig.
         """
         super(TransfoXLConfig, self).__init__(**kwargs)
-        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
-                        and isinstance(vocab_size_or_config_json_file, unicode)):
-            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
-                json_config = json.loads(reader.read())
-            for key, value in json_config.items():
-                self.__dict__[key] = value
-        elif isinstance(vocab_size_or_config_json_file, int):
-            self.n_token = vocab_size_or_config_json_file
+        self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
         self.cutoffs = []
         self.cutoffs.extend(cutoffs)
         self.tie_weight = tie_weight
@@ -138,7 +131,15 @@ class TransfoXLConfig(PretrainedConfig):
         self.init_range = init_range
         self.proj_init_std = proj_init_std
         self.init_std = init_std
-        else:
+        self.layer_norm_epsilon = layer_norm_epsilon
+        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+                        and isinstance(vocab_size_or_config_json_file, unicode)):
+            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+                json_config = json.loads(reader.read())
+            for key, value in json_config.items():
+                self.__dict__[key] = value
+        elif not isinstance(vocab_size_or_config_json_file, int):
             raise ValueError("First argument must be either a vocabulary size (int)"
                              " or the path to a pretrained model config file (str)")
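The constructor reorganization also changes when a JSON config file is read: the keyword defaults (including the new layer_norm_epsilon) are assigned first and the file is loaded last, so values stored in the file win. A sketch, assuming the pytorch_transformers API at this commit:

    # Sketch only, not part of the commit.
    import json, tempfile
    from pytorch_transformers import TransfoXLConfig

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"n_token": 267735, "layer_norm_epsilon": 1e-12}, f)
        path = f.name

    config = TransfoXLConfig(path)     # str argument: n_token starts as the -1 placeholder,
    print(config.n_token)              # then the JSON file overwrites it -> 267735
    print(config.layer_norm_epsilon)   # 1e-12 from the file, not the 1e-5 default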
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py

@@ -124,7 +124,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
         print("    Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
         print("-" * 100)
         if 'finetuned' in shortcut_name:
-            print("    Skipping fintenued checkpoint ")
+            print("    Skipping finetuned checkpoint ")
             continue
         config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
         model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
pytorch_transformers/modeling_tf_pytorch_utils.py

@@ -91,8 +91,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
         name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
         name = name[1:] # Remove level zero

+        # When should we transpose the weights
+        transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
+
         # Convert standard TF2.0 names in PyTorch names
-        transpose = bool(name[-1] == 'kernel')
         if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
             name[-1] = 'weight'
         if name[-1] == 'beta':
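The extra 'emb_projs' / 'out_projs' cases exist because those Transformer-XL projection matrices, like Dense kernels, are stored transposed relative to their PyTorch counterparts (the TF side of this commit creates them with shape (d_emb_i, d_proj)). A standalone sketch of why the transpose flag is needed at all:

    # Sketch only, not part of the commit.
    import tensorflow as tf
    import torch

    dense = tf.keras.layers.Dense(8)
    dense.build((None, 4))          # kernel is (in_features, out_features)
    linear = torch.nn.Linear(4, 8)  # weight is (out_features, in_features)

    print(dense.kernel.shape)       # (4, 8)
    print(linear.weight.shape)      # torch.Size([8, 4])

    # Loading one into the other therefore requires a transpose:
    with torch.no_grad():
        linear.weight.copy_(torch.from_numpy(dense.kernel.numpy().T))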
pytorch_transformers/modeling_tf_transfo_xl.py

@@ -66,7 +66,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
 class TFPositionwiseFF(tf.keras.layers.Layer):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, **kwargs):
         super(TFPositionwiseFF, self).__init__(**kwargs)
         self.d_model = d_model
@@ -75,10 +75,10 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
         self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0')
         self.drop_1 = tf.keras.layers.Dropout(dropout)
-        self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2')
+        self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._3')
         self.drop_2 = tf.keras.layers.Dropout(dropout)
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
         self.pre_lnorm = pre_lnorm
@@ -109,7 +109,8 @@ class TFPositionwiseFF(tf.keras.layers.Layer):
 class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs):
+                 r_r_bias=None, r_w_bias=None, output_attentions=False,
+                 layer_norm_epsilon=1e-5, **kwargs):
         super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs)
         self.output_attentions = output_attentions
@@ -124,7 +125,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         self.dropatt = tf.keras.layers.Dropout(dropatt)
         self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net')
-        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm')
+        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name='layer_norm')
         self.scale = 1 / (d_head ** 0.5)
@@ -247,6 +248,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
                  r_w_bias=None, r_r_bias=None, output_attentions=False,
+                 layer_norm_epsilon=1e-5,
                  **kwargs):
         super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs)
@@ -254,9 +256,12 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
                          d_head, dropout, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                          dropatt=dropatt, pre_lnorm=pre_lnorm, r_w_bias=r_w_bias, r_r_bias=r_r_bias,
-                         output_attentions=output_attentions, name='dec_attn')
+                         output_attentions=output_attentions,
+                         layer_norm_epsilon=layer_norm_epsilon,
+                         name='dec_attn')
-        self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout, pre_lnorm=pre_lnorm, name='pos_ff')
+        self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout,
+                                       pre_lnorm=pre_lnorm,
+                                       layer_norm_epsilon=layer_norm_epsilon,
+                                       name='pos_ff')

     def call(self, inputs, training=False):
         dec_inp, r, dec_attn_mask, mems, head_mask = inputs
@@ -300,7 +305,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
             d_emb_i = self.d_embed // (self.div_val ** i)
-            self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj), trainable=True, name='emb_projs._{}'.format(i)))
+            self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj), trainable=True, name='emb_projs_._{}'.format(i)))
         super(TFAdaptiveEmbedding, self).build(input_shape)

     def call(self, inp):
@@ -368,6 +373,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                     r_w_bias=None if self.untie_r else self.r_w_bias,
                     r_r_bias=None if self.untie_r else self.r_r_bias,
                     output_attentions=self.output_attentions,
+                    layer_norm_epsilon=config.layer_norm_epsilon,
                     name='layers_._{}'.format(i)))
         else: # learnable embeddings and absolute embeddings
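Before this change the TF layers pinned the epsilon to 1e-12 while the PyTorch layers used torch.nn.LayerNorm's 1e-5 default, so the two implementations normalized slightly differently. A short sketch of what the epsilon does (layer normalization divides by sqrt(variance + epsilon)):

    # Sketch only, not part of the commit.
    import tensorflow as tf

    x = tf.constant([[1.0, 2.0, 3.0, 4.0]])
    ln_small = tf.keras.layers.LayerNormalization(epsilon=1e-12)
    ln_large = tf.keras.layers.LayerNormalization(epsilon=1e-1)

    print(ln_small(x).numpy())  # essentially the exact standardized values
    print(ln_large(x).numpy())  # shrunk slightly toward zero by the larger epsilon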
pytorch_transformers/modeling_transfo_xl.py

@@ -194,7 +194,7 @@ class PositionalEmbedding(nn.Module):
 class PositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
         super(PositionwiseFF, self).__init__()
         self.d_model = d_model
@@ -208,7 +208,7 @@ class PositionwiseFF(nn.Module):
             nn.Dropout(dropout),
         )
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
         self.pre_lnorm = pre_lnorm
@@ -232,7 +232,8 @@ class PositionwiseFF(nn.Module):
 class RelPartialLearnableMultiHeadAttn(nn.Module):
     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                  tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
-                 r_r_bias=None, r_w_bias=None, output_attentions=False):
+                 r_r_bias=None, r_w_bias=None, output_attentions=False,
+                 layer_norm_epsilon=1e-5):
         super(RelPartialLearnableMultiHeadAttn, self).__init__()
         self.output_attentions = output_attentions
@@ -247,7 +248,7 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         self.dropatt = nn.Dropout(dropatt)
         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
+        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
         self.scale = 1 / (d_head ** 0.5)
@@ -359,14 +360,15 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
 class RelPartialLearnableDecoderLayer(nn.Module):
-    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
+    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5,
                  **kwargs):
         super(RelPartialLearnableDecoderLayer, self).__init__()

-        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
+        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
+                                                         layer_norm_epsilon=layer_norm_epsilon, **kwargs)
-        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
+        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'),
+                                     layer_norm_epsilon=layer_norm_epsilon)

     def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
@@ -613,7 +615,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                         dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                         r_w_bias=None if config.untie_r else self.r_w_bias,
                         r_r_bias=None if config.untie_r else self.r_r_bias,
-                        output_attentions=self.output_attentions)
+                        output_attentions=self.output_attentions,
+                        layer_norm_epsilon=config.layer_norm_epsilon)
                 )
         else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
             raise NotImplementedError # Removed them to avoid maintaining dead code
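On the PyTorch side nn.LayerNorm already defaults to eps=1e-5, so for the default configuration this change is behaviour-preserving; it only lets a checkpoint's config pick a different value. A quick sketch of that equivalence:

    # Sketch only, not part of the commit.
    import torch
    import torch.nn as nn

    d_model = 8
    old_style = nn.LayerNorm(d_model)            # implicit eps=1e-5
    new_style = nn.LayerNorm(d_model, eps=1e-5)  # explicit, as in the diff

    x = torch.randn(2, d_model)
    print(old_style.eps == new_style.eps)              # True
    print(torch.allclose(old_style(x), new_style(x)))  # True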