chenpangpang / transformers / Commits

Commit a7e01a24, authored Sep 24, 2019 by thomwolf

converting distilled/fine-tuned models

parent 8ba44ced

Showing 4 changed files with 90 additions and 39 deletions (+90 -39)
Files changed:

  pytorch_transformers/__init__.py                            +1  -1
  pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py   +69 -34
  pytorch_transformers/modeling_tf_distilbert.py              +3  -3
  pytorch_transformers/modeling_tf_pytorch_utils.py           +17 -1
pytorch_transformers/__init__.py

@@ -146,7 +146,7 @@ if _tf_available:
     from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer, TFDistilBertModel,
-                                         TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification,
+                                         TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TFDistilBertForQuestionAnswering,
                                          load_distilbert_pt_weights_in_tf2, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
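With this line changed, the package root now also exports the TF 2.0 question-answering head for DistilBERT. A minimal usage sketch (assuming TensorFlow 2.0 and this revision are installed; the config below is just a freshly initialised default, not something taken from the diff):

    # Instantiate the newly exported TF 2.0 QA head from a default config.
    from pytorch_transformers import DistilBertConfig, TFDistilBertForQuestionAnswering

    config = DistilBertConfig()                        # default DistilBERT hyper-parameters
    model = TFDistilBertForQuestionAnswering(config)   # Keras model producing start/end span logits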
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
@@ -24,43 +24,43 @@ import tensorflow as tf

 from pytorch_transformers import is_torch_available, cached_path

-from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+from pytorch_transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
                                   GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
                                   XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
                                   XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
                                   TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
                                   OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                                  RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                                  DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
+                                  RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                                  DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)

 if is_torch_available():
     import torch
     import numpy as np
-    from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+    from pytorch_transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
                                       GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
                                       XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
                                       XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
                                       TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
                                       OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                      RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                      DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
+                                      RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                      DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
-    (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+    (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
      GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
      XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
      XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
      TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
      OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-     RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
-     DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
-        None, None,
+     RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+     DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
+        None, None, None, None,
         None, None,
         None, None,
         None, None,
         None, None,
         None, None,
-        None, None,
-        None, None,)
+        None, None, None,
+        None, None, None,)

 import logging
@@ -68,22 +68,29 @@ logging.basicConfig(level=logging.INFO)

 MODEL_CLASSES = {
     'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
     'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
 }

-def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
+def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

     config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     # Initialise TF model
+    if config_file in aws_config_map:
+        config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
     config = config_class.from_json_file(config_file)
     config.output_hidden_states = True
     config.output_attentions = True
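The new MODEL_CLASSES entries map each fine-tuned shortcut name to its TF 2.0 class, loading function, PyTorch class and the AWS archive maps, and convert_pt_checkpoint_to_tf now resolves shortcut names through cached_path and accepts use_cached_models. A minimal sketch of driving it for one of the new entries (this assumes the script is importable as a module; the dump path is a placeholder):

    from pytorch_transformers.convert_pytorch_checkpoint_to_tf2 import convert_pt_checkpoint_to_tf

    # The shortcut name doubles as the key into the AWS config/model maps, so the same
    # string can be passed for both the checkpoint and the config.
    convert_pt_checkpoint_to_tf(model_type='bert-base-cased-finetuned-mrpc',
                                pytorch_checkpoint_path='bert-base-cased-finetuned-mrpc',
                                config_file='bert-base-cased-finetuned-mrpc',
                                tf_dump_path='/tmp/bert-base-cased-finetuned-mrpc-tf_model.h5',
                                compare_with_pt_model=False,
                                use_cached_models=True)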
@@ -91,6 +98,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model = model_class(config)

     # Load weights from tf checkpoint
+    if pytorch_checkpoint_path in aws_model_maps:
+        pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
     tf_model = loading_fct(tf_model, pytorch_checkpoint_path)

     if compare_with_pt_model:
@@ -117,7 +126,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model.save_weights(tf_dump_path, save_format='h5')


-def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
+def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
+                                     compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
     assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
     if args_model_type is None:
@@ -134,20 +144,39 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
         config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

-        for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
-            print("-" * 100)
-            print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
-            print("-" * 100)
-            if 'finetuned' in shortcut_name:
-                print(" Skipping finetuned checkpoint ")
-            config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models)
-            model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models)
-            convert_pt_checkpoint_to_tf(model_type, model_file, config_file, os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
+        if model_shortcut_names_or_path is None:
+            model_shortcut_names_or_path = list(aws_model_maps.keys())
+        if config_shortcut_names_or_path is None:
+            config_shortcut_names_or_path = model_shortcut_names_or_path
+
+        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
+                zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
+            print("-" * 100)
+            if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
+                if not only_convert_finetuned_models:
+                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
+                    continue
+                model_type = model_shortcut_name
+            elif only_convert_finetuned_models:
+                print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
+                continue
+            print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
+            print("-" * 100)
+
+            if config_shortcut_name in aws_config_map:
+                config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
+            else:
+                config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
+            if model_shortcut_name in aws_model_maps:
+                model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
+            else:
+                model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
+
+            convert_pt_checkpoint_to_tf(model_type, model_file, config_file, os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
                                         compare_with_pt_model=compare_with_pt_model)
             os.remove(config_file)
             os.remove(model_file)
@@ -176,23 +205,29 @@ if __name__ == "__main__":
                         help="The config json file corresponding to the pre-trained model. \n"
                              "This specifies the model architecture. If not given and "
                              "--pytorch_checkpoint_path is not given or is a shortcut name"
-                             "use the configuration associated to teh shortcut name on the AWS")
+                             "use the configuration associated to the shortcut name on the AWS")
     parser.add_argument("--compare_with_pt_model",
                         action='store_true',
                         help="Compare Tensorflow and PyTorch model predictions.")
+    parser.add_argument("--use_cached_models",
+                        action='store_true',
+                        help="Use cached models if possible instead of updating to latest checkpoint versions.")
+    parser.add_argument("--only_convert_finetuned_models",
+                        action='store_true',
+                        help="Only convert finetuned models.")
     args = parser.parse_args()

-    if args.pytorch_checkpoint_path is not None:
-        convert_pt_checkpoint_to_tf(args.model_type.lower(),
-                                    args.pytorch_checkpoint_path,
-                                    args.config_file,
-                                    args.tf_dump_path,
-                                    compare_with_pt_model=args.compare_with_pt_model)
-    else:
-        convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
-                                         args.tf_dump_path,
-                                         compare_with_pt_model=args.compare_with_pt_model,
-                                         use_cached_models=args.use_cached_models)
+    # if args.pytorch_checkpoint_path is not None:
+    #     convert_pt_checkpoint_to_tf(args.model_type.lower(),
+    #                                 args.pytorch_checkpoint_path,
+    #                                 args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
+    #                                 args.tf_dump_path,
+    #                                 compare_with_pt_model=args.compare_with_pt_model,
+    #                                 use_cached_models=args.use_cached_models)
+    # else:
+    convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
+                                     args.tf_dump_path,
+                                     model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
+                                     compare_with_pt_model=args.compare_with_pt_model,
+                                     use_cached_models=args.use_cached_models,
+                                     only_convert_finetuned_models=args.only_convert_finetuned_models)
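With the reworked __main__ block, a --pytorch_checkpoint_path shortcut is now routed through convert_all_pt_checkpoints_to_tf via model_shortcut_names_or_path instead of calling convert_pt_checkpoint_to_tf directly. An equivalent call from Python might look like the sketch below (illustrative only; tf_dump_path is a placeholder directory that must already exist, and the shortcut is one of the fine-tuned entries added above):

    from pytorch_transformers.convert_pytorch_checkpoint_to_tf2 import convert_all_pt_checkpoints_to_tf

    # Convert a single fine-tuned DistilBERT checkpoint; '-squad' in the shortcut name makes
    # the loop pick the question-answering classes from MODEL_CLASSES.
    convert_all_pt_checkpoints_to_tf('distilbert',
                                     '/tmp/tf_dump',
                                     model_shortcut_names_or_path=['distilbert-base-uncased-distilled-squad'],
                                     compare_with_pt_model=False,
                                     use_cached_models=True,
                                     only_convert_finetuned_models=True)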
pytorch_transformers/modeling_tf_distilbert.py
@@ -653,7 +653,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
         super(TFDistilBertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

-        self.distilbert = TFDistilBertModel(config, name="distilbert")
+        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
         self.pre_classifier = tf.keras.layers.Dense(config.dim, activation='relu', name="pre_classifier")
         self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
         self.dropout = tf.keras.layers.Dropout(config.seq_classif_dropout)
@@ -714,8 +714,8 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFDistilBertForQuestionAnswering
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
distilbert
=
TFDistilBertM
odel
(
config
,
name
=
"distilbert"
)
self
.
qa_outputs
=
tf
.
keras
.
layers
.
Dense
(
config
.
num_labels
,
name
=
'qa_output'
)
self
.
distilbert
=
TFDistilBertM
ainLayer
(
config
,
name
=
"distilbert"
)
self
.
qa_outputs
=
tf
.
keras
.
layers
.
Dense
(
config
.
num_labels
,
name
=
'qa_output
s
'
)
assert
config
.
num_labels
==
2
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
qa_dropout
)
...
...
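Both DistilBERT heads now wrap the shared TFDistilBertMainLayer (instead of nesting a full TFDistilBertModel) and the question-answering projection is renamed to 'qa_outputs', so the Keras layer names match the attribute names used by the PyTorch heads (self.distilbert, self.qa_outputs). A minimal check of the new naming (assumes TensorFlow 2.0 is installed):

    from pytorch_transformers import DistilBertConfig, TFDistilBertForQuestionAnswering

    # Layer names are fixed in __init__, so no forward pass is needed to inspect them.
    model = TFDistilBertForQuestionAnswering(DistilBertConfig())
    print(model.distilbert.name, model.qa_outputs.name)   # -> distilbert qa_outputs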
pytorch_transformers/modeling_tf_pytorch_utils.py
@@ -148,8 +148,24 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
     """ Load TF 2.0 HDF5 checkpoint in a PyTorch model
         We use HDF5 to easily do transfer learning
         (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).

         Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
             - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
             - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
     """
-    raise NotImplementedError
+    try:
+        import tensorflow as tf
+        import torch
+    except ImportError as e:
+        logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
+                     "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
+        raise e
+
+    tf_path = os.path.abspath(tf_checkpoint_path)
+    logger.info("Loading TensorFlow weights from {}".format(tf_path))
+    tf_state_dict = torch.load(tf_path, map_location='cpu')
+
+    return load_tf2_weights_in_pytorch_model(pt_model, tf_state_dict)
+
+
+def load_tf2_weights_in_pytorch_model(pt_model, tf_model):
+    """ Load TF2.0 symbolic weights in a PyTorch model
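load_tf2_checkpoint_in_pytorch_model previously just raised NotImplementedError; it now resolves the checkpoint path and defers to the new load_tf2_weights_in_pytorch_model helper. A sketch of the intended entry point (the path is a placeholder, and note that the committed body still reads the file with torch.load rather than HDF5, so this illustrates the API shape at this commit rather than a guaranteed working round trip):

    from pytorch_transformers import DistilBertConfig, DistilBertForQuestionAnswering
    from pytorch_transformers.modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model

    # Target PyTorch model that should receive the TF 2.0 weights (hypothetical checkpoint path).
    pt_model = DistilBertForQuestionAnswering(DistilBertConfig())
    pt_model = load_tf2_checkpoint_in_pytorch_model(pt_model, '/tmp/distilbert-squad-tf_model.h5')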