Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
da26bae6
Commit
da26bae6
authored
Oct 10, 2019
by
thomwolf
Browse files
adding more tests on TF and pytorch serialization - updating configuration for better serialization
parent
bb04edb4
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
90 additions
and
148 deletions
+90
-148
transformers/__init__.py
transformers/__init__.py
+7
-17
transformers/configuration_utils.py
transformers/configuration_utils.py
+2
-2
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+28
-26
transformers/modeling_tf_bert.py
transformers/modeling_tf_bert.py
+0
-10
transformers/modeling_tf_ctrl.py
transformers/modeling_tf_ctrl.py
+0
-10
transformers/modeling_tf_distilbert.py
transformers/modeling_tf_distilbert.py
+0
-10
transformers/modeling_tf_gpt2.py
transformers/modeling_tf_gpt2.py
+0
-10
transformers/modeling_tf_openai.py
transformers/modeling_tf_openai.py
+0
-10
transformers/modeling_tf_pytorch_utils.py
transformers/modeling_tf_pytorch_utils.py
+1
-3
transformers/modeling_tf_roberta.py
transformers/modeling_tf_roberta.py
+0
-10
transformers/modeling_tf_transfo_xl.py
transformers/modeling_tf_transfo_xl.py
+0
-10
transformers/modeling_tf_utils.py
transformers/modeling_tf_utils.py
+6
-5
transformers/modeling_tf_xlm.py
transformers/modeling_tf_xlm.py
+12
-16
transformers/modeling_tf_xlnet.py
transformers/modeling_tf_xlnet.py
+0
-9
transformers/tests/modeling_common_test.py
transformers/tests/modeling_common_test.py
+34
-0
No files found.
transformers/__init__.py
View file @
da26bae6
...
@@ -110,65 +110,55 @@ if is_tf_available():
...
@@ -110,65 +110,55 @@ if is_tf_available():
TFBertForMaskedLM
,
TFBertForNextSentencePrediction
,
TFBertForMaskedLM
,
TFBertForNextSentencePrediction
,
TFBertForSequenceClassification
,
TFBertForMultipleChoice
,
TFBertForSequenceClassification
,
TFBertForMultipleChoice
,
TFBertForTokenClassification
,
TFBertForQuestionAnswering
,
TFBertForTokenClassification
,
TFBertForQuestionAnswering
,
load_bert_pt_weights_in_tf2
,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_gpt2
import
(
TFGPT2PreTrainedModel
,
TFGPT2MainLayer
,
from
.modeling_tf_gpt2
import
(
TFGPT2PreTrainedModel
,
TFGPT2MainLayer
,
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
load_gpt2_pt_weights_in_tf2
,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_openai
import
(
TFOpenAIGPTPreTrainedModel
,
TFOpenAIGPTMainLayer
,
from
.modeling_tf_openai
import
(
TFOpenAIGPTPreTrainedModel
,
TFOpenAIGPTMainLayer
,
TFOpenAIGPTModel
,
TFOpenAIGPTLMHeadModel
,
TFOpenAIGPTDoubleHeadsModel
,
TFOpenAIGPTModel
,
TFOpenAIGPTLMHeadModel
,
TFOpenAIGPTDoubleHeadsModel
,
load_openai_gpt_pt_weights_in_tf2
,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_transfo_xl
import
(
TFTransfoXLPreTrainedModel
,
TFTransfoXLMainLayer
,
from
.modeling_tf_transfo_xl
import
(
TFTransfoXLPreTrainedModel
,
TFTransfoXLMainLayer
,
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_xlnet
import
(
TFXLNetPreTrainedModel
,
TFXLNetMainLayer
,
from
.modeling_tf_xlnet
import
(
TFXLNetPreTrainedModel
,
TFXLNetMainLayer
,
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForSequenceClassification
,
TFXLNetForQuestionAnsweringSimple
,
TFXLNetForQuestionAnsweringSimple
,
load_xlnet_pt_weights_in_tf2
,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_xlm
import
(
TFXLMPreTrainedModel
,
TFXLMMainLayer
,
from
.modeling_tf_xlm
import
(
TFXLMPreTrainedModel
,
TFXLMMainLayer
,
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMModel
,
TFXLMWithLMHeadModel
,
TFXLMForSequenceClassification
,
TFXLMForSequenceClassification
,
TFXLMForQuestionAnsweringSimple
,
TFXLMForQuestionAnsweringSimple
,
load_xlm_pt_weights_in_tf2
,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_roberta
import
(
TFRobertaPreTrainedModel
,
TFRobertaMainLayer
,
from
.modeling_tf_roberta
import
(
TFRobertaPreTrainedModel
,
TFRobertaMainLayer
,
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaModel
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_distilbert
import
(
TFDistilBertPreTrainedModel
,
TFDistilBertMainLayer
,
from
.modeling_tf_distilbert
import
(
TFDistilBertPreTrainedModel
,
TFDistilBertMainLayer
,
TFDistilBertModel
,
TFDistilBertForMaskedLM
,
TFDistilBertModel
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
,
TFDistilBertForSequenceClassification
,
TFDistilBertForQuestionAnswering
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_ctrl
import
(
TFCTRLPreTrainedModel
,
TFCTRLModel
,
from
.modeling_tf_ctrl
import
(
TFCTRLPreTrainedModel
,
TFCTRLModel
,
TFCTRLLMHeadModel
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
# TF 2.0 <=> PyTorch conversion utilities
# TF 2.0 <=> PyTorch conversion utilities
if
is_tf_available
()
and
is_torch_available
():
from
.modeling_tf_pytorch_utils
import
(
convert_tf_weight_name_to_pt_weight_name
,
from
.modeling_tf_pytorch_utils
import
(
convert_tf_weight_name_to_pt_weight_name
,
load_pytorch_checkpoint_in_tf2_model
,
load_pytorch_checkpoint_in_tf2_model
,
load_pytorch_weights_in_tf2_model
,
load_pytorch_weights_in_tf2_model
,
load_pytorch_model_in_tf2_model
,
load_pytorch_model_in_tf2_model
,
load_tf2_checkpoint_in_pytorch_model
,
load_tf2_checkpoint_in_pytorch_model
,
load_tf2_weights_in_pytorch_model
,
load_tf2_weights_in_pytorch_model
,
load_tf2_model_in_pytorch_model
)
load_tf2_model_in_pytorch_model
)
if
not
is_tf_available
()
and
not
is_torch_available
():
if
not
is_tf_available
()
and
not
is_torch_available
():
logger
.
warning
(
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
logger
.
warning
(
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
...
...
transformers/configuration_utils.py
View file @
da26bae6
...
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
...
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
config
=
cls
.
from_json_file
(
resolved_config_file
)
config
=
cls
.
from_json_file
(
resolved_config_file
)
if
hasattr
(
config
,
'pruned_heads'
):
if
hasattr
(
config
,
'pruned_heads'
):
config
.
pruned_heads
=
dict
((
int
(
key
),
set
(
value
)
)
for
key
,
value
in
config
.
pruned_heads
.
items
())
config
.
pruned_heads
=
dict
((
int
(
key
),
value
)
for
key
,
value
in
config
.
pruned_heads
.
items
())
# Update config with kwargs if needed
# Update config with kwargs if needed
to_remove
=
[]
to_remove
=
[]
...
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
...
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
for
key
in
to_remove
:
for
key
in
to_remove
:
kwargs
.
pop
(
key
,
None
)
kwargs
.
pop
(
key
,
None
)
logger
.
info
(
"Model config %s"
,
config
)
logger
.
info
(
"Model config %s"
,
str
(
config
)
)
if
return_unused_kwargs
:
if
return_unused_kwargs
:
return
config
,
kwargs
return
config
,
kwargs
else
:
else
:
...
...
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
da26bae6
...
@@ -24,15 +24,16 @@ import tensorflow as tf
...
@@ -24,15 +24,16 @@ import tensorflow as tf
from
transformers
import
is_torch_available
,
cached_path
from
transformers
import
is_torch_available
,
cached_path
from
transformers
import
(
BertConfig
,
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
,
load_bert_pt_weights_in_tf2
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
from
transformers
import
(
load_pytorch_checkpoint_in_tf2_model
,
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
BertConfig
,
TFBertForPreTraining
,
TFBertForQuestionAnswering
,
TFBertForSequenceClassification
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2Config
,
TFGPT2LMHeadModel
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetConfig
,
TFXLNetLMHeadModel
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMConfig
,
TFXLMWithLMHeadModel
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
load_openai_gpt_pt_weights_in_tf2
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
...
@@ -71,27 +72,27 @@ import logging
...
@@ -71,27 +72,27 @@ import logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
MODEL_CLASSES
=
{
MODEL_CLASSES
=
{
'bert'
:
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert'
:
(
BertConfig
,
TFBertForPreTraining
,
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-large-uncased-whole-word-masking-finetuned-squad'
:
(
BertConfig
,
TFBertForQuestionAnswering
,
load_bert_pt_weights_in_tf2
,
BertForQuestionAnswering
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-large-uncased-whole-word-masking-finetuned-squad'
:
(
BertConfig
,
TFBertForQuestionAnswering
,
BertForQuestionAnswering
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-large-cased-whole-word-masking-finetuned-squad'
:
(
BertConfig
,
TFBertForQuestionAnswering
,
load_bert_pt_weights_in_tf2
,
BertForQuestionAnswering
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-large-cased-whole-word-masking-finetuned-squad'
:
(
BertConfig
,
TFBertForQuestionAnswering
,
BertForQuestionAnswering
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-base-cased-finetuned-mrpc'
:
(
BertConfig
,
TFBertForSequenceClassification
,
load_bert_pt_weights_in_tf2
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'bert-base-cased-finetuned-mrpc'
:
(
BertConfig
,
TFBertForSequenceClassification
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'gpt2'
:
(
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'gpt2'
:
(
GPT2Config
,
TFGPT2LMHeadModel
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'transfo-xl'
:
(
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'transfo-xl'
:
(
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'openai-gpt'
:
(
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
load_openai_gpt_pt_weights_in_tf2
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'openai-gpt'
:
(
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta'
:
(
RobertaConfig
,
TFRobertaForMaskedLM
,
load_roberta_pt_weights_in_tf2
,
RobertaForMaskedLM
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta'
:
(
RobertaConfig
,
TFRobertaForMaskedLM
,
RobertaForMaskedLM
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
load_distilbert_pt_weights_in_tf2
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
}
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
if
model_type
not
in
MODEL_CLASSES
:
if
model_type
not
in
MODEL_CLASSES
:
raise
ValueError
(
"Unrecognized model type, should be one of {}."
.
format
(
list
(
MODEL_CLASSES
.
keys
())))
raise
ValueError
(
"Unrecognized model type, should be one of {}."
.
format
(
list
(
MODEL_CLASSES
.
keys
())))
config_class
,
model_class
,
loading_fct
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
config_class
,
model_class
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
# Initialise TF model
# Initialise TF model
if
config_file
in
aws_config_map
:
if
config_file
in
aws_config_map
:
...
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
...
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
# Load weights from tf checkpoint
# Load weights from tf checkpoint
if
pytorch_checkpoint_path
in
aws_model_maps
:
if
pytorch_checkpoint_path
in
aws_model_maps
:
pytorch_checkpoint_path
=
cached_path
(
aws_model_maps
[
pytorch_checkpoint_path
],
force_download
=
not
use_cached_models
)
pytorch_checkpoint_path
=
cached_path
(
aws_model_maps
[
pytorch_checkpoint_path
],
force_download
=
not
use_cached_models
)
tf_model
=
loading_fct
(
tf_model
,
pytorch_checkpoint_path
)
# Load PyTorch checkpoint in tf2 model:
tf_model
=
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
)
if
compare_with_pt_model
:
if
compare_with_pt_model
:
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
...
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
...
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
if
model_type
not
in
MODEL_CLASSES
:
if
model_type
not
in
MODEL_CLASSES
:
raise
ValueError
(
"Unrecognized model type {}, should be one of {}."
.
format
(
model_type
,
list
(
MODEL_CLASSES
.
keys
())))
raise
ValueError
(
"Unrecognized model type {}, should be one of {}."
.
format
(
model_type
,
list
(
MODEL_CLASSES
.
keys
())))
config_class
,
model_class
,
loading_fct
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
config_class
,
model_class
,
pt_model_class
,
aws_model_maps
,
aws_config_map
=
MODEL_CLASSES
[
model_type
]
if
model_shortcut_names_or_path
is
None
:
if
model_shortcut_names_or_path
is
None
:
model_shortcut_names_or_path
=
list
(
aws_model_maps
.
keys
())
model_shortcut_names_or_path
=
list
(
aws_model_maps
.
keys
())
...
...
transformers/modeling_tf_bert.py
View file @
da26bae6
...
@@ -30,7 +30,6 @@ import tensorflow as tf
...
@@ -30,7 +30,6 @@ import tensorflow as tf
from
.configuration_bert
import
BertConfig
from
.configuration_bert
import
BertConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
}
def
load_bert_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
""" Gaussian Error Linear Unit.
""" Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
...
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
...
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
BertConfig
config_class
=
BertConfig
pretrained_model_archive_map
=
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_bert_pt_weights_in_tf2
base_model_prefix
=
"bert"
base_model_prefix
=
"bert"
...
...
transformers/modeling_tf_ctrl.py
View file @
da26bae6
...
@@ -27,20 +27,11 @@ import tensorflow as tf
...
@@ -27,20 +27,11 @@ import tensorflow as tf
from
.configuration_ctrl
import
CTRLConfig
from
.configuration_ctrl
import
CTRLConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
,
TFSharedEmbeddings
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
,
TFSharedEmbeddings
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"ctrl"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"
}
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"ctrl"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"
}
def
load_ctrl_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
angle_defn
(
pos
,
i
,
d_model_size
):
def
angle_defn
(
pos
,
i
,
d_model_size
):
angle_rates
=
1
/
np
.
power
(
10000
,
(
2
*
(
i
//
2
))
/
np
.
float32
(
d_model_size
))
angle_rates
=
1
/
np
.
power
(
10000
,
(
2
*
(
i
//
2
))
/
np
.
float32
(
d_model_size
))
return
pos
*
angle_rates
return
pos
*
angle_rates
...
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
...
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
config_class
=
CTRLConfig
config_class
=
CTRLConfig
pretrained_model_archive_map
=
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
load_pt_weights
=
load_ctrl_pt_weights_in_tf2
CTRL_START_DOCSTRING
=
r
""" CTRL model was proposed in
CTRL_START_DOCSTRING
=
r
""" CTRL model was proposed in
...
...
transformers/modeling_tf_distilbert.py
View file @
da26bae6
...
@@ -31,7 +31,6 @@ import tensorflow as tf
...
@@ -31,7 +31,6 @@ import tensorflow as tf
from
.configuration_distilbert
import
DistilBertConfig
from
.configuration_distilbert
import
DistilBertConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
shape_list
,
get_initializer
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -66,14 +65,6 @@ def gelu_new(x):
...
@@ -66,14 +65,6 @@ def gelu_new(x):
(
np
.
sqrt
(
2
/
np
.
pi
)
*
(
x
+
0.044715
*
tf
.
pow
(
x
,
3
)))))
(
np
.
sqrt
(
2
/
np
.
pi
)
*
(
x
+
0.044715
*
tf
.
pow
(
x
,
3
)))))
return
x
*
cdf
return
x
*
cdf
def
load_distilbert_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
attns_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
tf_inputs
=
[
inputs_list
,
attns_list
]
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
class
TFEmbeddings
(
tf
.
keras
.
layers
.
Layer
):
class
TFEmbeddings
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
config
,
**
kwargs
):
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFEmbeddings
,
self
).
__init__
(
**
kwargs
)
super
(
TFEmbeddings
,
self
).
__init__
(
**
kwargs
)
...
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
...
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
DistilBertConfig
config_class
=
DistilBertConfig
pretrained_model_archive_map
=
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_distilbert_pt_weights_in_tf2
base_model_prefix
=
"distilbert"
base_model_prefix
=
"distilbert"
...
...
transformers/modeling_tf_gpt2.py
View file @
da26bae6
...
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
...
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
TFSequenceSummary
,
shape_list
,
get_initializer
)
TFSequenceSummary
,
shape_list
,
get_initializer
)
from
.configuration_gpt2
import
GPT2Config
from
.configuration_gpt2
import
GPT2Config
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
...
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
"distilgpt2"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5"
,}
"distilgpt2"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5"
,}
def
load_gpt2_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
"""Gaussian Error Linear Unit.
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
This is a smoother version of the RELU.
...
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
...
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
GPT2Config
config_class
=
GPT2Config
pretrained_model_archive_map
=
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_gpt2_pt_weights_in_tf2
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
...
...
transformers/modeling_tf_openai.py
View file @
da26bae6
...
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
...
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
TFSequenceSummary
,
shape_list
,
get_initializer
)
TFSequenceSummary
,
shape_list
,
get_initializer
)
from
.configuration_openai
import
OpenAIGPTConfig
from
.configuration_openai
import
OpenAIGPTConfig
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"
}
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"openai-gpt"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"
}
def
load_openai_gpt_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
def
gelu
(
x
):
def
gelu
(
x
):
"""Gaussian Error Linear Unit.
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
This is a smoother version of the RELU.
...
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
...
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
OpenAIGPTConfig
config_class
=
OpenAIGPTConfig
pretrained_model_archive_map
=
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_openai_gpt_pt_weights_in_tf2
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
...
...
transformers/modeling_tf_pytorch_utils.py
View file @
da26bae6
...
@@ -25,8 +25,6 @@ import numpy
...
@@ -25,8 +25,6 @@ import numpy
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
def
convert_tf_weight_name_to_pt_weight_name
(
tf_name
,
start_prefix_to_remove
=
''
):
def
convert_tf_weight_name_to_pt_weight_name
(
tf_name
,
start_prefix_to_remove
=
''
):
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
...
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
...
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
raise
e
raise
e
if
tf_inputs
is
None
:
if
tf_inputs
is
None
:
tf_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
tf_inputs
=
tf
_model
.
dummy_inputs
if
tf_inputs
is
not
None
:
if
tf_inputs
is
not
None
:
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure model is built
...
...
transformers/modeling_tf_roberta.py
View file @
da26bae6
...
@@ -26,7 +26,6 @@ import tensorflow as tf
...
@@ -26,7 +26,6 @@ import tensorflow as tf
from
.configuration_roberta
import
RobertaConfig
from
.configuration_roberta
import
RobertaConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
from
.modeling_tf_bert
import
TFBertEmbeddings
,
TFBertMainLayer
,
gelu
,
gelu_new
from
.modeling_tf_bert
import
TFBertEmbeddings
,
TFBertMainLayer
,
gelu
,
gelu_new
...
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'roberta-large-mnli'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5"
,
'roberta-large-mnli'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5"
,
}
}
def
load_roberta_pt_weights_in_tf2
(
tf_model
,
pytorch_checkpoint_path
):
# build the network
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
return
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pytorch_checkpoint_path
,
tf_inputs
=
tf_inputs
)
class
TFRobertaEmbeddings
(
TFBertEmbeddings
):
class
TFRobertaEmbeddings
(
TFBertEmbeddings
):
"""
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
...
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
...
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
RobertaConfig
config_class
=
RobertaConfig
pretrained_model_archive_map
=
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_roberta_pt_weights_in_tf2
base_model_prefix
=
"roberta"
base_model_prefix
=
"roberta"
...
...
transformers/modeling_tf_transfo_xl.py
View file @
da26bae6
...
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
...
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFConv1D
,
TFSequenceSummary
,
shape_list
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFConv1D
,
TFSequenceSummary
,
shape_list
,
get_initializer
from
.modeling_tf_transfo_xl_utilities
import
TFAdaptiveSoftmaxMask
from
.modeling_tf_transfo_xl_utilities
import
TFAdaptiveSoftmaxMask
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl-wt103'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5"
,
'transfo-xl-wt103'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5"
,
}
}
def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Build the TF Transformer-XL model with dummy inputs, then load PyTorch weights into it.

    Running the model once on dummy token ids forces Keras to instantiate all
    variables so the PyTorch checkpoint can be mapped onto them.
    """
    dummy_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    tf_model(dummy_ids, training=False)  # build the network
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_ids)
class
TFPositionalEmbedding
(
tf
.
keras
.
layers
.
Layer
):
class
TFPositionalEmbedding
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
demb
,
**
kwargs
):
def
__init__
(
self
,
demb
,
**
kwargs
):
super
(
TFPositionalEmbedding
,
self
).
__init__
(
**
kwargs
)
super
(
TFPositionalEmbedding
,
self
).
__init__
(
**
kwargs
)
...
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
...
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
TransfoXLConfig
config_class
=
TransfoXLConfig
pretrained_model_archive_map
=
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_transfo_xl_pt_weights_in_tf2
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
...
...
transformers/modeling_tf_utils.py
View file @
da26bae6
...
@@ -25,9 +25,11 @@ import tensorflow as tf
...
@@ -25,9 +25,11 @@ import tensorflow as tf
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
from
.file_utils
import
cached_path
,
WEIGHTS_NAME
,
TF_WEIGHTS_NAME
,
TF2_WEIGHTS_NAME
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
DUMMY_INPUTS
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
class
TFPreTrainedModel
(
tf
.
keras
.
Model
):
class
TFPreTrainedModel
(
tf
.
keras
.
Model
):
r
""" Base class for all TF models.
r
""" Base class for all TF models.
...
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
"""
"""
config_class
=
None
config_class
=
None
pretrained_model_archive_map
=
{}
pretrained_model_archive_map
=
{}
load_pt_weights
=
lambda
model
,
config
,
path
:
None
base_model_prefix
=
""
base_model_prefix
=
""
dummy_inputs
=
tf
.
constant
(
DUMMY_INPUTS
)
# dummy inputs to build the network
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
super
(
TFPreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
...
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
...
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
if
from_pt
:
if
from_pt
:
# Load from a PyTorch checkpoint
# Load from a PyTorch checkpoint
return
cls
.
load_p
t_weights
(
model
,
resolved_archive_file
)
return
load_p
ytorch_checkpoint_in_tf2_model
(
model
,
resolved_archive_file
)
inputs
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
ret
=
model
(
model
.
dummy_inputs
,
training
=
False
)
# build the network with dummy inputs
ret
=
model
(
inputs
,
training
=
False
)
# build the network with dummy inputs
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
assert
os
.
path
.
isfile
(
resolved_archive_file
),
"Error retrieving file {}"
.
format
(
resolved_archive_file
)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
model
.
load_weights
(
resolved_archive_file
,
by_name
=
True
)
ret
=
model
(
inputs
,
training
=
False
)
# Make sure restore ops are run
ret
=
model
(
model
.
dummy_
inputs
,
training
=
False
)
# Make sure restore ops are run
return
model
return
model
...
...
transformers/modeling_tf_xlm.py
View file @
da26bae6
...
@@ -25,9 +25,8 @@ import numpy as np
...
@@ -25,9 +25,8 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.configuration_xlm
import
XLMConfig
from
.configuration_xlm
import
XLMConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
TFSequenceSummary
,
shape_list
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
TFSequenceSummary
,
shape_list
,
get_initializer
,
DUMMY_INPUTS
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
}
def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Build the TF XLM model with dummy inputs, then load PyTorch weights into it.

    XLM optionally takes language ids: a ``langs`` tensor is only built when
    the configuration enables language embeddings for more than one language,
    otherwise ``None`` is passed in its slot.
    """
    input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
        langs = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
    else:
        langs = None
    tf_inputs = [input_ids, attention_mask, langs]
    tf_model(tf_inputs, training=False)  # build the network
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
def
create_sinusoidal_embeddings
(
n_pos
,
dim
,
out
):
position_enc
=
np
.
array
([
position_enc
=
np
.
array
([
[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
dim
)
for
j
in
range
(
dim
)]
[
pos
/
np
.
power
(
10000
,
2
*
(
j
//
2
)
/
dim
)
for
j
in
range
(
dim
)]
...
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
...
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
XLMConfig
config_class
=
XLMConfig
pretrained_model_archive_map
=
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_xlm_pt_weights_in_tf2
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
@
property
def
dummy_inputs
(
self
):
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
inputs_list
=
tf
.
constant
([[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]])
attns_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
if
self
.
config
.
use_lang_emb
and
self
.
config
.
n_langs
>
1
:
langs_list
=
tf
.
constant
([[
1
,
1
,
0
,
0
,
1
],
[
1
,
1
,
1
,
0
,
0
],
[
1
,
0
,
0
,
1
,
1
]])
else
:
langs_list
=
None
return
[
inputs_list
,
attns_list
,
langs_list
]
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
XLM_START_DOCSTRING
=
r
""" The XLM model was proposed in
`Cross-lingual Language Model Pretraining`_
`Cross-lingual Language Model Pretraining`_
...
...
transformers/modeling_tf_xlnet.py
View file @
da26bae6
...
@@ -30,7 +30,6 @@ import tensorflow as tf
...
@@ -30,7 +30,6 @@ import tensorflow as tf
from
.configuration_xlnet
import
XLNetConfig
from
.configuration_xlnet
import
XLNetConfig
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
TFSequenceSummary
,
shape_list
,
get_initializer
from
.modeling_tf_utils
import
TFPreTrainedModel
,
TFSharedEmbeddings
,
TFSequenceSummary
,
shape_list
,
get_initializer
from
.file_utils
import
add_start_docstrings
from
.file_utils
import
add_start_docstrings
from
.modeling_tf_pytorch_utils
import
load_pytorch_checkpoint_in_tf2_model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
...
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
}
}
def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
    """Build the TF XLNet model with dummy inputs, then load PyTorch weights into it.

    The dummy forward pass makes Keras create every variable so checkpoint
    tensors have somewhere to be loaded.
    """
    dummy_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
    tf_model(dummy_ids, training=False)  # build the network
    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=dummy_ids)
def
gelu
(
x
):
def
gelu
(
x
):
""" Implementation of the gelu activation function.
""" Implementation of the gelu activation function.
XLNet is using OpenAI GPT's gelu
XLNet is using OpenAI GPT's gelu
...
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
...
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
"""
"""
config_class
=
XLNetConfig
config_class
=
XLNetConfig
pretrained_model_archive_map
=
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_model_archive_map
=
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
load_pt_weights
=
load_xlnet_pt_weights_in_tf2
base_model_prefix
=
"transformer"
base_model_prefix
=
"transformer"
...
...
transformers/tests/modeling_common_test.py
View file @
da26bae6
...
@@ -17,8 +17,10 @@ from __future__ import division
...
@@ -17,8 +17,10 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
copy
import
copy
import
sys
import
os
import
os
import
shutil
import
shutil
import
tempfile
import
json
import
json
import
random
import
random
import
uuid
import
uuid
...
@@ -31,6 +33,7 @@ from transformers import is_torch_available
...
@@ -31,6 +33,7 @@ from transformers import is_torch_available
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
import
numpy
as
np
from
transformers
import
(
PretrainedConfig
,
PreTrainedModel
,
from
transformers
import
(
PretrainedConfig
,
PreTrainedModel
,
BertModel
,
BertConfig
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BertModel
,
BertConfig
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
...
@@ -38,6 +41,20 @@ if is_torch_available():
...
@@ -38,6 +41,20 @@ if is_torch_available():
else
:
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
# Python 2/3 compatibility shim: a fast pickle, a usable TemporaryDirectory
# context manager, and a `unicode` alias on Python 3.
if sys.version_info[0] >= 3:
    import pickle
    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str
else:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""

        def __enter__(self):
            # Create the directory lazily on entry and hand its path to the caller.
            self.name = tempfile.mkdtemp()
            return self.name

        def __exit__(self, exc_type, exc_value, traceback):
            # Always clean up, even if the body raised.
            shutil.rmtree(self.name)
def
_config_zero_init
(
config
):
def
_config_zero_init
(
config
):
configs_no_init
=
copy
.
deepcopy
(
config
)
configs_no_init
=
copy
.
deepcopy
(
config
)
...
@@ -57,6 +74,23 @@ class CommonTestCases:
...
@@ -57,6 +74,23 @@ class CommonTestCases:
test_resize_embeddings
=
True
test_resize_embeddings
=
True
test_head_masking
=
True
test_head_masking
=
True
    def test_save_load(self):
        """Round-trip every model class through save_pretrained/from_pretrained
        and check the outputs are numerically unchanged."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.eval()  # disable dropout so both forward passes are deterministic
            with torch.no_grad():
                outputs = model(**inputs_dict)
            with TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)
                with torch.no_grad():
                    after_outputs = model(**inputs_dict)
                # Largest absolute element-wise difference between the first
                # output tensor before and after serialization.
                max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
                self.assertLessEqual(max_diff, 1e-5)
def
test_initialization
(
self
):
def
test_initialization
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment