chenpangpang / transformers · Commits

Commit da26bae6, authored Oct 10, 2019 by thomwolf

adding more tests on TF and pytorch serialization - updating configuration for better serialization

Parent: bb04edb4
Showing 15 changed files with 90 additions and 148 deletions (+90 −148):

transformers/__init__.py                            +7   −17
transformers/configuration_utils.py                 +2   −2
transformers/convert_pytorch_checkpoint_to_tf2.py   +28  −26
transformers/modeling_tf_bert.py                    +0   −10
transformers/modeling_tf_ctrl.py                    +0   −10
transformers/modeling_tf_distilbert.py              +0   −10
transformers/modeling_tf_gpt2.py                    +0   −10
transformers/modeling_tf_openai.py                  +0   −10
transformers/modeling_tf_pytorch_utils.py           +1   −3
transformers/modeling_tf_roberta.py                 +0   −10
transformers/modeling_tf_transfo_xl.py              +0   −10
transformers/modeling_tf_utils.py                   +6   −5
transformers/modeling_tf_xlm.py                     +12  −16
transformers/modeling_tf_xlnet.py                   +0   −9
transformers/tests/modeling_common_test.py          +34  −0
transformers/__init__.py

```diff
@@ -110,59 +110,49 @@ if is_tf_available():
                                    TFBertForMaskedLM, TFBertForNextSentencePrediction,
                                    TFBertForSequenceClassification, TFBertForMultipleChoice,
                                    TFBertForTokenClassification, TFBertForQuestionAnswering,
-                                   load_bert_pt_weights_in_tf2,
                                    TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
                                    TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
-                                   load_gpt2_pt_weights_in_tf2,
                                    TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
                                      TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel,
                                      TFOpenAIGPTDoubleHeadsModel,
-                                     load_openai_gpt_pt_weights_in_tf2,
                                      TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
                                          TFTransfoXLModel, TFTransfoXLLMHeadModel,
-                                         load_transfo_xl_pt_weights_in_tf2,
                                          TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer,
                                     TFXLNetModel, TFXLNetLMHeadModel,
                                     TFXLNetForSequenceClassification,
                                     TFXLNetForQuestionAnsweringSimple,
-                                    load_xlnet_pt_weights_in_tf2,
                                     TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer,
                                   TFXLMModel, TFXLMWithLMHeadModel,
                                   TFXLMForSequenceClassification,
                                   TFXLMForQuestionAnsweringSimple,
-                                  load_xlm_pt_weights_in_tf2,
                                   TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
                                       TFRobertaModel, TFRobertaForMaskedLM,
                                       TFRobertaForSequenceClassification,
-                                      load_roberta_pt_weights_in_tf2,
                                       TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
                                          TFDistilBertModel, TFDistilBertForMaskedLM,
                                          TFDistilBertForSequenceClassification,
                                          TFDistilBertForQuestionAnswering,
-                                         load_distilbert_pt_weights_in_tf2,
                                          TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
     from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel,
                                    TFCTRLLMHeadModel,
-                                   load_ctrl_pt_weights_in_tf2,
                                    TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)

 # TF 2.0 <=> PyTorch conversion utilities
 if is_tf_available() and is_torch_available():
-    from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
+    from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
+                                            load_pytorch_checkpoint_in_tf2_model,
+                                            load_pytorch_weights_in_tf2_model,
+                                            load_pytorch_model_in_tf2_model,
...
```
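The net effect of this hunk is visible at import time: the per-model `load_*_pt_weights_in_tf2` helpers leave the public API, and the generic conversion utilities are exported in their place. A minimal sketch of the new import surface, guarded the same way the package itself guards it (both backends must be installed):

```python
from transformers import is_tf_available, is_torch_available

if is_tf_available() and is_torch_available():
    # The generic TF 2.0 <=> PyTorch utilities replace the per-model helpers.
    from transformers import (convert_tf_weight_name_to_pt_weight_name,
                              load_pytorch_checkpoint_in_tf2_model,
                              load_pytorch_weights_in_tf2_model,
                              load_pytorch_model_in_tf2_model)
```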
transformers/configuration_utils.py

```diff
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
         config = cls.from_json_file(resolved_config_file)

         if hasattr(config, 'pruned_heads'):
-            config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
+            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

         # Update config with kwargs if needed
         to_remove = []
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
         for key in to_remove:
             kwargs.pop(key, None)

-        logger.info("Model config %s", config)
+        logger.info("Model config %s", str(config))
         if return_unused_kwargs:
             return config, kwargs
         else:
```
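The `pruned_heads` change is about JSON round-tripping: `set` values cannot be dumped by `json`, so a config that converted its pruned-heads values to sets on load could no longer be re-serialized cleanly. A minimal illustration (not library code), assuming a `pruned_heads` entry as it appears in a saved `config.json`:

```python
import json

# As read back from a saved config.json: head indices are plain lists.
pruned_heads = {"0": [1, 2, 3]}

# Old behavior: converting the values to sets breaks re-serialization.
as_sets = {int(k): set(v) for k, v in pruned_heads.items()}
try:
    json.dumps(as_sets)
except TypeError as err:
    print("sets are not JSON-serializable:", err)

# New behavior: keep the lists, so save -> load -> save round-trips.
as_lists = {int(k): v for k, v in pruned_heads.items()}
print(json.dumps(as_lists))  # {"0": [1, 2, 3]}
```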
transformers/convert_pytorch_checkpoint_to_tf2.py

```diff
@@ -24,15 +24,16 @@ import tensorflow as tf

 from transformers import is_torch_available, cached_path

-from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-                          CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+from transformers import (load_pytorch_checkpoint_in_tf2_model,
+                          BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          GPT2Config, TFGPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          XLNetConfig, TFXLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          XLMConfig, TFXLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+                          CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)

 if is_torch_available():
     import torch
@@ -71,27 +72,27 @@ import logging
 logging.basicConfig(level=logging.INFO)

 MODEL_CLASSES = {
-    'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
-    'ctrl': (CTRLConfig, TFCTRLLMHeadModel, load_ctrl_pt_weights_in_tf2, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
+    'bert': (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'gpt2': (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlnet': (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'xlm': (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'roberta': (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
+    'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP)
 }

 def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

-    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+    config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     # Initialise TF model
     if config_file in aws_config_map:
@@ -105,7 +106,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     # Load weights from tf checkpoint
     if pytorch_checkpoint_path in aws_model_maps:
         pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)

-    tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
+    # Load PyTorch checkpoint in tf2 model:
+    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

     if compare_with_pt_model:
         inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
@@ -147,7 +149,7 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
     if model_type not in MODEL_CLASSES:
         raise ValueError("Unrecognized model type {}, should be one of {}.".format(model_type, list(MODEL_CLASSES.keys())))

-    config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
+    config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

     if model_shortcut_names_or_path is None:
         model_shortcut_names_or_path = list(aws_model_maps.keys())
```
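With the per-model loading functions gone from `MODEL_CLASSES`, the converter always goes through `load_pytorch_checkpoint_in_tf2_model`. A hedged usage sketch of the function defined above; the file paths are illustrative, not from the commit:

```python
from transformers.convert_pytorch_checkpoint_to_tf2 import convert_pt_checkpoint_to_tf

# Paths are illustrative; any local PyTorch checkpoint + config pair works.
convert_pt_checkpoint_to_tf(
    model_type='bert',
    pytorch_checkpoint_path='./pytorch_model.bin',
    config_file='./config.json',
    tf_dump_path='./tf_model.h5',
    compare_with_pt_model=True,  # also run both models and compare outputs
)
```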
transformers/modeling_tf_bert.py

```diff
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_bert import BertConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -52,14 +51,6 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_bert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """ Gaussian Error Linear Unit.
     Original Implementation of the gelu activation function in Google Bert repo when initially created.
@@ -545,7 +536,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
     """
     config_class = BertConfig
     pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_bert_pt_weights_in_tf2
     base_model_prefix = "bert"
```
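This deletion pattern repeats in every modeling_tf_* file below: the near-identical `load_*_pt_weights_in_tf2` helper and the `load_pt_weights` class attribute are dropped, and callers use the shared helper directly. A hedged sketch of the consolidated call (config and checkpoint paths are illustrative):

```python
from transformers import (BertConfig, TFBertForPreTraining,
                          load_pytorch_checkpoint_in_tf2_model)

config = BertConfig.from_json_file('./config.json')  # illustrative path
tf_model = TFBertForPreTraining(config)

# The shared helper builds the network from the model's dummy_inputs
# (see modeling_tf_utils.py below), then maps the PyTorch weights over.
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, './pytorch_model.bin')
```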
transformers/modeling_tf_ctrl.py

```diff
@@ -27,20 +27,11 @@ import tensorflow as tf
 from .configuration_ctrl import CTRLConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, TFSharedEmbeddings
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

 TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}

-def load_ctrl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def angle_defn(pos, i, d_model_size):
     angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
     return pos * angle_rates
@@ -327,7 +318,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
     config_class = CTRLConfig
     pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
-    load_pt_weights = load_ctrl_pt_weights_in_tf2

 CTRL_START_DOCSTRING = r""" CTRL model was proposed in
```
transformers/modeling_tf_distilbert.py

```diff
@@ -31,7 +31,6 @@ import tensorflow as tf
 from .configuration_distilbert import DistilBertConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -66,14 +65,6 @@ def gelu_new(x):
            (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
     return x * cdf

-def load_distilbert_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-    attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    tf_inputs = [inputs_list, attns_list]
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFEmbeddings(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFEmbeddings, self).__init__(**kwargs)
@@ -454,7 +445,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
     """
     config_class = DistilBertConfig
     pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_distilbert_pt_weights_in_tf2
     base_model_prefix = "distilbert"
```
transformers/modeling_tf_gpt2.py

```diff
@@ -32,7 +32,6 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -42,14 +41,6 @@ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models
                                         "distilgpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/distilgpt2-tf_model.h5",}

-def load_gpt2_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """Gaussian Error Linear Unit.
     This is a smoother version of the RELU.
@@ -350,7 +341,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
     """
     config_class = GPT2Config
     pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_gpt2_pt_weights_in_tf2
     base_model_prefix = "transformer"
```
transformers/modeling_tf_openai.py

```diff
@@ -32,21 +32,12 @@ from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
                                 TFSequenceSummary, shape_list, get_initializer)
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

 TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-tf_model.h5"}

-def load_openai_gpt_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """Gaussian Error Linear Unit.
     This is a smoother version of the RELU.
@@ -335,7 +326,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
     """
     config_class = OpenAIGPTConfig
     pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_openai_gpt_pt_weights_in_tf2
     base_model_prefix = "transformer"
```
transformers/modeling_tf_pytorch_utils.py

```diff
@@ -25,8 +25,6 @@ import numpy

 logger = logging.getLogger(__name__)

-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

 def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
     """ Convert a TF 2.0 model variable name in a pytorch model weight name.
@@ -105,7 +103,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
         raise e

     if tf_inputs is None:
-        tf_inputs = tf.constant(DUMMY_INPUTS)
+        tf_inputs = tf_model.dummy_inputs

     if tf_inputs is not None:
         tfo = tf_model(tf_inputs, training=False)  # Make sure model is built
```
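The graph-building inputs now default to whatever the model advertises via `dummy_inputs`, while an explicit `tf_inputs` argument still takes precedence. A sketch of both call shapes, assuming a local checkpoint file at an illustrative path:

```python
import tensorflow as tf
from transformers import (BertConfig, TFBertForPreTraining,
                          load_pytorch_checkpoint_in_tf2_model)

tf_model = TFBertForPreTraining(BertConfig())

# Default: the helper falls back to tf_model.dummy_inputs to build the graph.
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, './pytorch_model.bin')

# Explicit: pass tf_inputs to build with a custom batch/sequence shape.
my_inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]])
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, './pytorch_model.bin',
                                                tf_inputs=my_inputs)
```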
transformers/modeling_tf_roberta.py

```diff
@@ -26,7 +26,6 @@ import tensorflow as tf
 from .configuration_roberta import RobertaConfig
 from .modeling_tf_utils import TFPreTrainedModel, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
 from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu, gelu_new
@@ -38,14 +37,6 @@ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-tf_model.h5",
 }

-def load_roberta_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFRobertaEmbeddings(TFBertEmbeddings):
     """
     Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
@@ -96,7 +87,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """
     config_class = RobertaConfig
     pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_roberta_pt_weights_in_tf2
     base_model_prefix = "roberta"
```
transformers/modeling_tf_transfo_xl.py

```diff
@@ -33,7 +33,6 @@ from .configuration_transfo_xl import TransfoXLConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list, get_initializer
 from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -41,14 +40,6 @@ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5",
 }

-def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 class TFPositionalEmbedding(tf.keras.layers.Layer):
     def __init__(self, demb, **kwargs):
         super(TFPositionalEmbedding, self).__init__(**kwargs)
@@ -577,7 +568,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
     """
     config_class = TransfoXLConfig
     pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_transfo_xl_pt_weights_in_tf2
     base_model_prefix = "transformer"
```
transformers/modeling_tf_utils.py

```diff
@@ -25,9 +25,11 @@ import tensorflow as tf

 from .configuration_utils import PretrainedConfig
 from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
+from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)

+DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]

 class TFPreTrainedModel(tf.keras.Model):
     r""" Base class for all TF models.
@@ -48,8 +50,8 @@ class TFPreTrainedModel(tf.keras.Model):
     """
     config_class = None
     pretrained_model_archive_map = {}
-    load_pt_weights = lambda model, config, path: None
     base_model_prefix = ""
+    dummy_inputs = tf.constant(DUMMY_INPUTS)  # dummy inputs to build the network

     def __init__(self, config, *inputs, **kwargs):
         super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
@@ -262,17 +264,16 @@ class TFPreTrainedModel(tf.keras.Model):
         if from_pt:
             # Load from a PyTorch checkpoint
-            return cls.load_pt_weights(model, resolved_archive_file)
+            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)

-        inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-        ret = model(inputs, training=False)  # build the network with dummy inputs
+        ret = model(model.dummy_inputs, training=False)  # build the network with dummy inputs

         assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
         model.load_weights(resolved_archive_file, by_name=True)

-        ret = model(inputs, training=False)  # Make sure restore ops are run
+        ret = model(model.dummy_inputs, training=False)  # Make sure restore ops are run

         return model
```
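From the user's side, the consolidated path means `from_pretrained(..., from_pt=True)` works uniformly across architectures. A hedged sketch; the directory name is illustrative and would need to contain a `config.json` plus a PyTorch `pytorch_model.bin`:

```python
from transformers import TFBertModel

# from_pt=True now routes through load_pytorch_checkpoint_in_tf2_model
# instead of a per-model load_pt_weights hook.
tf_model = TFBertModel.from_pretrained('./my-bert-checkpoint', from_pt=True)
```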
transformers/modeling_tf_xlm.py

```diff
@@ -25,9 +25,8 @@ import numpy as np
 import tensorflow as tf

 from .configuration_xlm import XLMConfig
-from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer, DUMMY_INPUTS
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -45,19 +44,6 @@ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_xlm_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    # build the network
-    inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
-    attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    if tf_model.config.use_lang_emb and tf_model.config.n_langs > 1:
-        langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
-    else:
-        langs_list = None
-    tf_inputs = [inputs_list, attns_list, langs_list]
-    tfo = tf_model(tf_inputs, training=False)
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([
         [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
@@ -441,9 +427,19 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLMConfig
     pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_xlm_pt_weights_in_tf2
     base_model_prefix = "transformer"

+    @property
+    def dummy_inputs(self):
+        # Sometimes XLM has language embeddings so don't forget to build them as well if needed
+        inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
+        attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        if self.config.use_lang_emb and self.config.n_langs > 1:
+            langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
+        else:
+            langs_list = None
+        return [inputs_list, attns_list, langs_list]

 XLM_START_DOCSTRING = r""" The XLM model was proposed in
     `Cross-lingual Language Model Pretraining`_
```
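XLM shows the intended extension point: a model whose forward pass needs more than token ids overrides `dummy_inputs` as a property instead of shipping its own loading function. A sketch of the same pattern for a hypothetical subclass (`MyTFModel` is not part of the commit):

```python
import tensorflow as tf
from transformers.modeling_tf_utils import TFPreTrainedModel

class MyTFModel(TFPreTrainedModel):  # hypothetical, for illustration only
    @property
    def dummy_inputs(self):
        # Build with an attention mask as well, so all input-dependent
        # weights get created before any checkpoint is restored.
        input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]])
        attention_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0]])
        return [input_ids, attention_mask]
```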
transformers/modeling_tf_xlnet.py

```diff
@@ -30,7 +30,6 @@ import tensorflow as tf
 from .configuration_xlnet import XLNetConfig
 from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list, get_initializer
 from .file_utils import add_start_docstrings
-from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

 logger = logging.getLogger(__name__)
@@ -41,13 +40,6 @@ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
 }

-def load_xlnet_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path):
-    inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
-    tf_inputs = tf.constant(inputs_list)
-    tfo = tf_model(tf_inputs, training=False)  # build the network
-    return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs)

 def gelu(x):
     """ Implementation of the gelu activation function.
     XLNet is using OpenAI GPT's gelu
@@ -670,7 +662,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
     """
     config_class = XLNetConfig
     pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
-    load_pt_weights = load_xlnet_pt_weights_in_tf2
     base_model_prefix = "transformer"
```
transformers/tests/modeling_common_test.py

```diff
@@ -17,8 +17,10 @@ from __future__ import division
 from __future__ import print_function

+import copy
+import sys
 import os
 import shutil
 import tempfile
 import json
 import random
 import uuid
@@ -31,6 +33,7 @@ from transformers import is_torch_available

 if is_torch_available():
     import torch
+    import numpy as np

     from transformers import (PretrainedConfig, PreTrainedModel,
                               BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -38,6 +41,20 @@ if is_torch_available():
 else:
     pytestmark = pytest.mark.skip("Require Torch")

+if sys.version_info[0] == 2:
+    import cPickle as pickle
+
+    class TemporaryDirectory(object):
+        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
+        def __enter__(self):
+            self.name = tempfile.mkdtemp()
+            return self.name
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self.name)
+else:
+    import pickle
+    TemporaryDirectory = tempfile.TemporaryDirectory
+    unicode = str

 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
@@ -57,6 +74,23 @@ class CommonTestCases:
         test_resize_embeddings = True
         test_head_masking = True

+        def test_save_load(self):
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            for model_class in self.all_model_classes:
+                model = model_class(config)
+                model.eval()
+                with torch.no_grad():
+                    outputs = model(**inputs_dict)
+
+                with TemporaryDirectory() as tmpdirname:
+                    model.save_pretrained(tmpdirname)
+                    model = model_class.from_pretrained(tmpdirname)
+                    with torch.no_grad():
+                        after_outputs = model(**inputs_dict)
+
+                    max_diff = np.amax(np.abs(after_outputs[0].numpy() - outputs[0].numpy()))
+                    self.assertLessEqual(max_diff, 1e-5)

         def test_initialization(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
```
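The new `test_save_load` is the serialization test the commit message refers to: save, reload, and compare outputs numerically. A standalone version of the same round-trip check; the model choice, directory, and tolerance here are illustrative:

```python
import numpy as np
import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig())  # randomly initialized, good enough for a round-trip check
model.eval()
input_ids = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0]])

with torch.no_grad():
    before = model(input_ids)[0].numpy()

model.save_pretrained('/tmp/bert-roundtrip')
model = BertModel.from_pretrained('/tmp/bert-roundtrip')
model.eval()

with torch.no_grad():
    after = model(input_ids)[0].numpy()

# Outputs before and after the save/load cycle should match to tight tolerance.
assert np.amax(np.abs(after - before)) <= 1e-5
```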