Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2b11fa51
Commit
2b11fa51
authored
Sep 23, 2019
by
thomwolf
Browse files
update __init__ and conversion script
parent
6448396d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
20 deletions
+53
-20
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+18
-0
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
+35
-20
No files found.
pytorch_transformers/__init__.py
View file @
2b11fa51
...
@@ -113,6 +113,11 @@ if _tf_available:
...
@@ -113,6 +113,11 @@ if _tf_available:
load_gpt2_pt_weights_in_tf2
,
load_gpt2_pt_weights_in_tf2
,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_openai
import
(
TFOpenAIGPTPreTrainedModel
,
TFOpenAIGPTMainLayer
,
TFOpenAIGPTModel
,
TFOpenAIGPTLMHeadModel
,
TFOpenAIGPTDoubleHeadsModel
,
load_openai_gpt_pt_weights_in_tf2
,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_transfo_xl
import
(
TFTransfoXLPreTrainedModel
,
TFTransfoXLMainLayer
,
from
.modeling_tf_transfo_xl
import
(
TFTransfoXLPreTrainedModel
,
TFTransfoXLMainLayer
,
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
TFTransfoXLModel
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
load_transfo_xl_pt_weights_in_tf2
,
...
@@ -132,6 +137,19 @@ if _tf_available:
...
@@ -132,6 +137,19 @@ if _tf_available:
load_xlm_pt_weights_in_tf2
,
load_xlm_pt_weights_in_tf2
,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_roberta
import
(
TFRobertaPreTrainedModel
,
TFRobertaMainLayer
,
TFRobertaModel
,
TFRobertaLMHead
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_distilbert
import
(
TFDistilBertPreTrainedModel
,
TFDistilBertMainLayer
,
TFDistilBertModel
,
TFDistilBertForMaskedLM
,
TFDistilBertForSequenceClassification
,
TFDistilBertForSequenceClassification
,
load_distilbert_pt_weights_in_tf2
,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# Files and general utilities
# Files and general utilities
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
from
.file_utils
import
(
PYTORCH_TRANSFORMERS_CACHE
,
PYTORCH_PRETRAINED_BERT_CACHE
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
cached_path
,
add_start_docstrings
,
add_end_docstrings
,
...
...
pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
2b11fa51
...
@@ -24,31 +24,43 @@ import tensorflow as tf
...
@@ -24,31 +24,43 @@ import tensorflow as tf
from
pytorch_transformers
import
is_torch_available
,
cached_path
from
pytorch_transformers
import
is_torch_available
,
cached_path
from
pytorch_transformers
import
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
from
pytorch_transformers
import
(
BertConfig
,
TFBertForPreTraining
,
load_bert_pt_weights_in_tf2
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2Config
,
TFGPT2LMHeadModel
,
load_gpt2_pt_weights_in_tf2
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,)
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
load_openai_gpt_pt_weights_in_tf2
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaLMHead
,
load_roberta_pt_weights_in_tf2
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
load_distilbert_pt_weights_in_tf2
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
import
numpy
as
np
import
numpy
as
np
from
pytorch_transformers
import
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
from
pytorch_transformers
import
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,)
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
else
:
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
(
BertForPreTraining
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,)
=
(
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
None
,
None
,
None
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
None
,
None
,
None
,
RobertaForMaskedLM
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
None
,
None
,
None
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,)
=
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,)
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,)
import
logging
import
logging
...
@@ -60,6 +72,9 @@ MODEL_CLASSES = {
...
@@ -60,6 +72,9 @@ MODEL_CLASSES = {
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlnet'
:
(
XLNetConfig
,
TFXLNetLMHeadModel
,
load_xlnet_pt_weights_in_tf2
,
XLNetLMHeadModel
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'xlm'
:
(
XLMConfig
,
TFXLMWithLMHeadModel
,
load_xlm_pt_weights_in_tf2
,
XLMWithLMHeadModel
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'transfo-xl'
:
(
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'transfo-xl'
:
(
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'openai-gpt'
:
(
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
load_openai_gpt_pt_weights_in_tf2
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'roberta'
:
(
RobertaConfig
,
TFRobertaLMHead
,
load_roberta_pt_weights_in_tf2
,
RobertaForMaskedLM
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
load_distilbert_pt_weights_in_tf2
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
}
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
):
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment