Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
48b438ff
Commit
48b438ff
authored
Oct 09, 2019
by
thomwolf
Browse files
doc and conversion
parent
c19b8e4a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
4 deletions
+18
-4
docs/source/pretrained_models.rst
docs/source/pretrained_models.rst
+3
-0
transformers/__init__.py
transformers/__init__.py
+5
-0
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+10
-4
No files found.
docs/source/pretrained_models.rst
View file @
48b438ff
...
...
@@ -119,5 +119,8 @@ Here is the full list of the currently provided pretrained models together with
| | | | The DistilBERT model distilled from the BERT model `bert-base-uncased` checkpoint, with an additional linear layer. |
| | | (see `details <https://medium.com/huggingface/distilbert-8cf3380435b5>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
| | | | Salesforce's Large-sized CTRL English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
.. <https://huggingface.co/transformers/examples.html>`__
\ No newline at end of file
transformers/__init__.py
View file @
48b438ff
...
...
@@ -155,6 +155,11 @@ if is_tf_available():
load_distilbert_pt_weights_in_tf2
,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_ctrl
import
(
TFCTRLPreTrainedModel
,
TFCTRLModel
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
# TF 2.0 <=> PyTorch conversion utilities
if
is_tf_available
()
and
is_torch_available
():
from
.modeling_tf_pytorch_utils
import
(
convert_tf_weight_name_to_pt_weight_name
,
...
...
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
48b438ff
...
...
@@ -31,7 +31,8 @@ from transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAns
TransfoXLConfig
,
TFTransfoXLLMHeadModel
,
load_transfo_xl_pt_weights_in_tf2
,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
load_openai_gpt_pt_weights_in_tf2
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
if
is_torch_available
():
import
torch
...
...
@@ -43,7 +44,8 @@ if is_torch_available():
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
(
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
...
...
@@ -52,7 +54,8 @@ else:
TransfoXLLMHeadModel
,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
,
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,)
=
(
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
...
...
@@ -60,7 +63,8 @@ else:
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,)
None
,
None
,
None
,
None
,
None
)
import
logging
...
...
@@ -80,6 +84,7 @@ MODEL_CLASSES = {
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
load_roberta_pt_weights_in_tf2
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
load_distilbert_pt_weights_in_tf2
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
load_distilbert_pt_weights_in_tf2
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
load_ctrl_pt_weights_in_tf2
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
...
...
@@ -228,6 +233,7 @@ if __name__ == "__main__":
convert_all_pt_checkpoints_to_tf
(
args
.
model_type
.
lower
()
if
args
.
model_type
is
not
None
else
None
,
args
.
tf_dump_path
,
model_shortcut_names_or_path
=
[
args
.
pytorch_checkpoint_path
]
if
args
.
pytorch_checkpoint_path
is
not
None
else
None
,
config_shortcut_names_or_path
=
[
args
.
config_file
]
if
args
.
config_file
is
not
None
else
None
,
compare_with_pt_model
=
args
.
compare_with_pt_model
,
use_cached_models
=
args
.
use_cached_models
,
only_convert_finetuned_models
=
args
.
only_convert_finetuned_models
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment