chenpangpang / transformers · Commits

Commit 646711e1, authored Sep 11, 2019 by thomwolf
standardize scopes names - add conversion methods
Parent: 4356f791

Showing 7 changed files with 177 additions and 279 deletions (+177 -279)
  pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py                    +5    -3
  pytorch_transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py   +10   -2
  pytorch_transformers/modeling_tf_bert.py                                      +14   -72
  pytorch_transformers/modeling_tf_gpt2.py                                      +6    -70
  pytorch_transformers/modeling_tf_pytorch_utils.py                             +122  -0
  pytorch_transformers/modeling_tf_xlm.py                                       +16   -69
  pytorch_transformers/modeling_tf_xlnet.py                                     +4    -63
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py  (+5 -3)

@@ -25,12 +25,13 @@ from pytorch_transformers import is_torch_available
 from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
-                                  XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2)
+                                  XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2,
+                                  XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2,)
 
 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel
+    from pytorch_transformers import BertForPreTraining, GPT2LMHeadModel, XLNetLMHeadModel, XLMWithLMHeadModel
 else:
     BertForPreTraining, GPT2LMHeadModel = None, None

@@ -42,6 +43,7 @@ MODEL_CLASSES = {
     'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining),
     'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel),
     'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel),
 }
 
 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):

@@ -58,7 +60,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model = model_class(config)
     # Load weights from tf checkpoint
-    tf_model = loading_fct(tf_model, config, pytorch_checkpoint_path)
+    tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
 
     if compare_with_pt_model:
         inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
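Note on the new signatures: each MODEL_CLASSES entry is now (config class, TF 2.0 class, loading function, PyTorch class), and the loading functions take only the TF model and the PyTorch checkpoint path. A minimal sketch of the BERT path (the file names are hypothetical, not part of the commit):

    from pytorch_transformers import BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2

    config = BertConfig.from_json_file('bert_config.json')                 # hypothetical config path
    tf_model = TFBertForPreTraining(config)                                 # TF 2.0 model, weights not built yet
    tf_model = load_bert_pt_weights_in_tf2(tf_model, 'pytorch_model.bin')   # new two-argument signature
    tf_model.save_weights('tf_model.h5')                                    # standard Keras HDF5 save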
pytorch_transformers/convert_xlm_original_pytorch_checkpoint_to_pytorch.py  (+10 -2)

@@ -33,7 +33,15 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
     # Load checkpoint
     chkpt = torch.load(xlm_checkpoint_path, map_location='cpu')
 
-    model = chkpt['model']
+    state_dict = chkpt['model']
+
+    # We have the base model one level deeper than the original XLM repository
+    two_levels_state_dict = {}
+    for k, v in state_dict.items():
+        if 'pred_layer' in k:
+            two_levels_state_dict[k] = v
+        else:
+            two_levels_state_dict['transformer.' + k] = v
 
     config = chkpt['params']
     config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray)))

@@ -47,7 +55,7 @@ def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_p
     pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file']
 
     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
-    torch.save(model, pytorch_weights_dump_path)
+    torch.save(two_levels_state_dict, pytorch_weights_dump_path)
 
     print("Save configuration file to {}".format(pytorch_config_dump_path))
     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
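The effect of the remapping is easiest to see on a couple of keys: everything except the prediction layer is re-rooted under 'transformer.', since the base model now sits one level deeper than in the original XLM repository. A small illustration (the keys and values below are hypothetical, not taken from a real checkpoint):

    original = {'attentions.0.q_lin.weight': 'w0', 'pred_layer.proj.weight': 'w1'}
    remapped = {(k if 'pred_layer' in k else 'transformer.' + k): v
                for k, v in original.items()}
    print(remapped)
    # {'transformer.attentions.0.q_lin.weight': 'w0', 'pred_layer.proj.weight': 'w1'}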
pytorch_transformers/modeling_tf_bert.py  (+14 -72)

@@ -30,6 +30,7 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .modeling_tf_utils import TFPreTrainedModel
 from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 
 logger = logging.getLogger(__name__)

@@ -51,71 +52,12 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
 
-def load_bert_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
-        We use HDF5 to easily do transfer learning
-        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
-    """
-    try:
-        import re
-        import torch
-        import numpy
-        from tensorflow.python.keras import backend as K
-    except ImportError:
-        logger.error("Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
-            "https://pytorch.org/ for installation instructions.")
-        raise
-
-    pt_path = os.path.abspath(pytorch_checkpoint_path)
-    logger.info("Loading PyTorch weights from {}".format(pt_path))
-    # Load pytorch model
-    state_dict = torch.load(pt_path, map_location='cpu')
-
+def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-
-    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
-    weight_value_tuples = []
-    all_pytorch_weights = set(list(state_dict.keys()))
-    for symbolic_weight in symbolic_weights:
-        name = symbolic_weight.name
-        name = name.replace('cls_mlm', 'cls')  # We had to split this layer in two in the TF model to be
-        name = name.replace('cls_nsp', 'cls')  # able to do transfer learning (Keras only allow to remove full layers)
-        name = name.replace(':0', '')
-        name = name.replace('__', '/')
-        name = name.split('/')
-        name = name[1:]
-
-        transpose = bool(name[-1] == 'kernel')
-        if name[-1] == 'kernel' or name[-1] == 'embeddings':
-            name[-1] = 'weight'
-
-        name = '.'.join(name)
-        assert name in state_dict, "{} not found in PyTorch model".format(name)
-        array = state_dict[name].numpy()
-
-        if transpose:
-            array = numpy.transpose(array)
-
-        try:
-            assert list(symbolic_weight.shape) == list(array.shape)
-        except AssertionError as e:
-            e.args += (symbolic_weight.shape, array.shape)
-            raise e
-
-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
-
-        weight_value_tuples.append((symbolic_weight, array))
-        all_pytorch_weights.discard(name)
-
-    K.batch_set_value(weight_value_tuples)
-
-    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run
-
-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
-
-    return tf_model
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 
 def gelu(x):

@@ -391,7 +333,7 @@ class TFBertEncoder(tf.keras.layers.Layer):
         super(TFBertEncoder, self).__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
-        self.layer = [TFBertLayer(config, name='layer__{}'.format(i)) for i in range(config.num_hidden_layers)]
+        self.layer = [TFBertLayer(config, name='layer_._{}'.format(i)) for i in range(config.num_hidden_layers)]
 
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs

@@ -730,15 +672,15 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         super(TFBertForPreTraining, self).__init__(config, *inputs, **kwargs)
 
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
-        self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
+        self.nsp = TFBertNSPHead(config, name='nsp___cls')
+        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
 
         sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.cls_mlm(sequence_output, training=training)
-        seq_relationship_score = self.cls_nsp(pooled_output)
+        prediction_scores = self.mlm(sequence_output, training=training)
+        seq_relationship_score = self.nsp(pooled_output)
 
         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

@@ -773,13 +715,13 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         super(TFBertForMaskedLM, self).__init__(config, *inputs, **kwargs)
 
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_mlm = TFBertMLMHead(config, self.bert.embeddings, name='cls_mlm')
+        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name='mlm___cls')
 
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
 
         sequence_output = outputs[0]
-        prediction_scores = self.cls_mlm(sequence_output, training=training)
+        prediction_scores = self.mlm(sequence_output, training=training)
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

@@ -816,13 +758,13 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         super(TFBertForNextSentencePrediction, self).__init__(config, *inputs, **kwargs)
 
         self.bert = TFBertMainLayer(config, name='bert')
-        self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
+        self.nsp = TFBertNSPHead(config, name='nsp___cls')
 
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
 
         pooled_output = outputs[1]
-        seq_relationship_score = self.cls_nsp(pooled_output)
+        seq_relationship_score = self.nsp(pooled_output)
 
         outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
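The renames above follow the conventions this commit introduces: the two pre-training heads are now 'mlm' and 'nsp' in Python, while their TF scope names 'mlm___cls' and 'nsp___cls' record that both map back onto the single PyTorch attribute 'cls', and list layers use 'layer_._{}' instead of 'layer__{}'. That is what lets the shared converter added below in modeling_tf_pytorch_utils.py recover PyTorch keys without per-model string hacks. A rough sketch of the '___' rule on a hypothetical weight name:

    import re

    # Hypothetical TF 2.0 variable name produced by the new scope convention.
    tf_name = 'tf_bert_for_pre_training/mlm___cls/predictions/bias:0'

    name = tf_name.replace(':0', '')                     # drop the output/device suffix
    name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name)   # 'mlm___cls' -> 'cls'
    name = name.replace('_._', '/')                      # list entries become extra levels
    name = re.sub(r'//+', '/', name)
    pt_key = '.'.join(name.split('/')[1:])               # drop the root scope, join with '.'
    print(pt_key)                                        # cls.predictions.bias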
pytorch_transformers/modeling_tf_gpt2.py  (+6 -70)

@@ -32,6 +32,7 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                 TFSequenceSummary, shape_list)
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 
 logger = logging.getLogger(__name__)

@@ -40,77 +41,12 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
                                         "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"}
 
-def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format ... """
-    [the remaining removed lines are the same generic PyTorch-to-TF weight-loading code shown in full for modeling_tf_bert.py above; the GPT-2 copy also mapped 'gamma' -> 'weight' and 'beta' -> 'bias', dropped two leading scope levels, and expanded or squeezed array dimensions to fit the Conv1D weights]
+def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
+    tfo = tf_model(tf_inputs, training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 
 def gelu(x):

@@ -282,7 +218,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         self.h = [TFBlock(config.n_ctx,
                           config,
                           scale=True,
-                          name='h__{}'.format(i)) for i in range(config.n_layer)]
+                          name='h_._{}'.format(i)) for i in range(config.n_layer)]
         self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')
 
     def _resize_token_embeddings(self, new_num_tokens):
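One thing every one of these thin wrappers keeps is the dummy forward pass: TF 2.0 Keras models create their variables lazily, so the model has to be called once on some inputs before there is anything for the converter to assign. A generic sketch with a toy model (illustrative only, not part of the diff):

    import tensorflow as tf

    class TinyModel(tf.keras.Model):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.embed = tf.keras.layers.Embedding(10, 4)
            self.dense = tf.keras.layers.Dense(2)

        def call(self, inputs, training=False):
            return self.dense(self.embed(inputs))

    model = TinyModel()
    print(len(model.weights))    # 0: no variables exist yet
    _ = model(tf.constant([[7, 6, 0, 0, 1]]), training=False)   # build the network
    print(len(model.weights))    # 3: embedding matrix + dense kernel and bias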
pytorch_transformers/modeling_tf_pytorch_utils.py  (new file mode 100644, +122 -0)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - TF 2.0 general utilities."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging

from pytorch_transformers import is_tf_available, is_torch_available

logger = logging.getLogger(__name__)


def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path):
    """ Load pytorch checkpoints in a TF 2.0 model
        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
            - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    """
    if not is_tf_available() or not is_torch_available():
        logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
        raise ImportError

    import torch

    pt_path = os.path.abspath(pytorch_checkpoint_path)
    logger.info("Loading PyTorch weights from {}".format(pt_path))

    pt_state_dict = torch.load(pt_path, map_location='cpu')

    return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)


def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict):
    """ Load pytorch state_dict in a TF 2.0 model.
        Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
            - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
            - '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    """
    try:
        import re
        import torch
        import numpy
        from tensorflow.python.keras import backend as K
    except ImportError as e:
        logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
        raise e

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    weight_value_tuples = []
    all_pytorch_weights = set(list(pt_state_dict.keys()))
    for symbolic_weight in symbolic_weights:
        name = symbolic_weight.name
        name = name.replace(':0', '')                        # device ids
        name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name)   # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
        name = name.replace('_._', '/')                      # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
        name = re.sub(r'//+', '/', name)                     # Remove empty levels at the end
        name = name.split('/')                               # Convert from TF2.0 '/' separators to PyTorch '.' separators
        name = name[1:]                                      # Remove level zero

        # Convert standard TF2.0 names in PyTorch names
        transpose = bool(name[-1] == 'kernel')
        if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
            name[-1] = 'weight'
        if name[-1] == 'beta':
            name[-1] = 'bias'

        name = '.'.join(name)
        assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
        array = pt_state_dict[name].numpy()

        if transpose:
            array = numpy.transpose(array)

        try:
            assert list(symbolic_weight.shape) == list(array.shape)
        except AssertionError as e:
            e.args += (symbolic_weight.shape, array.shape)
            raise e

        logger.info("Initialize TF weight {}".format(symbolic_weight.name))

        weight_value_tuples.append((symbolic_weight, array))
        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)

    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))

    return tf_model


def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
    """ Load TF 2.0 HDF5 checkpoint in a PyTorch model
        We use HDF5 to easily do transfer learning
        (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
    """
    raise NotImplementedError


def load_tf2_weights_in_pytorch_model(pt_model, tf_model):
    """ Load TF2.0 symbolic weights in a PyTorch model
    """
    raise NotImplementedError
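Assuming the conventions documented in the docstrings above, the '_._' rule is what lets a plain Python list of Keras layers (named 'layer_._{}', 'h_._{}', 'attentions_._{}', ...) line up with a PyTorch nn.ModuleList. A rough sketch on a hypothetical BERT weight name, including the kernel/weight rename and the transpose flag:

    import re

    # Hypothetical TF 2.0 variable created by a layer list such as
    # self.layer = [TFBertLayer(config, name='layer_._{}'.format(i)) for i in range(...)]
    tf_name = 'tf_bert_for_pre_training/bert/encoder/layer_._3/attention/self/query/kernel:0'

    name = tf_name.replace(':0', '')
    name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name)
    name = name.replace('_._', '/')          # 'layer_._3' -> 'layer/3', an nn.ModuleList index
    name = re.sub(r'//+', '/', name)
    name = name.split('/')[1:]               # drop the root scope
    transpose = name[-1] == 'kernel'         # Dense kernels are transposed relative to nn.Linear weights
    if name[-1] in ('kernel', 'embeddings', 'gamma'):
        name[-1] = 'weight'
    print('.'.join(name), transpose)
    # bert.encoder.layer.3.attention.self.query.weight True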
pytorch_transformers/modeling_tf_xlm.py  (+16 -69)

@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import logging
 import math
+import os
 import itertools
 
 import numpy as np

@@ -26,6 +27,7 @@ import tensorflow as tf
 from .configuration_xlm import XLMConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
 from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 
 logger = logging.getLogger(__name__)

@@ -43,71 +45,16 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
 
-def load_xlm_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format ... """
-    [the remaining removed lines are the same generic PyTorch-to-TF weight-loading code shown in full for modeling_tf_bert.py above, including the 'gamma'/'beta' renames now handled centrally]
+def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
+    # build the network
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
+    attns_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
+    langs_list = [[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]
     tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
+    tf_attns = tf.constant(attns_list)
+    tf_langs = tf.constant(langs_list)
+    tfo = tf_model([tf_inputs, tf_attns, tf_langs], training=False)
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 
 def create_sinusoidal_embeddings(n_pos, dim, out):

@@ -320,13 +267,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
         # self.encoder_attn = tf.keras.layers.LayerList()
 
         for i in range(self.n_layers):
-            self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions__{}'.format(i)))
-            self.layer_norm1.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm1__{}'.format(i)))
+            self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
+            self.layer_norm1.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm1_._{}'.format(i)))
             # if self.is_decoder:
             #     self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
             #     self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
-            self.ffns.append(TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name='ffns__{}'.format(i)))
-            self.layer_norm2.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm2__{}'.format(i)))
+            self.ffns.append(TFTransformerFFN(self.dim, self.hidden_dim, self.dim, config=config, name='ffns_._{}'.format(i)))
+            self.layer_norm2.append(tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='layer_norm2_._{}'.format(i)))
 
         if hasattr(config, "pruned_heads"):
             pruned_heads = config.pruned_heads.copy().items()

@@ -667,8 +614,8 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
     """
     def __init__(self, config, *inputs, **kwargs):
         super(TFXLMWithLMHeadModel, self).__init__(config, *inputs, **kwargs)
-        self.transformer = TFXLMMainLayer(config, name='transformer')
-        self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer')
+        self.transformer = TFXLMMainLayer(config, name='transformer___')
+        self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name='pred_layer_._proj')
 
     def call(self, inputs, training=False):
pytorch_transformers/modeling_tf_xlnet.py  (+4 -63)

@@ -30,6 +30,7 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
 from .file_utils import add_start_docstrings
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 
 logger = logging.getLogger(__name__)

@@ -40,71 +41,11 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }
 
-def load_xlnet_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):
-    """ Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format ... """
-    [the remaining removed lines are the same generic PyTorch-to-TF weight-loading code shown in full for modeling_tf_bert.py above, including the 'gamma'/'beta' renames now handled centrally]
+def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
     inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
     tf_inputs = tf.constant(inputs_list)
     tfo = tf_model(tf_inputs, training=False)  # build the network
+    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
 
 def gelu(x):

@@ -430,7 +371,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         self.initializer_range = config.initializer_range
 
         self.word_embedding = TFSharedEmbeddings(config.n_token, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
-        self.layer = [TFXLNetLayer(config, name='layer__{}'.format(i)) for i in range(config.n_layer)]
+        self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
         self.dropout = tf.keras.layers.Dropout(config.dropout)
 
     def build(self, input_shape):