Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f6f38253
"tests/vscode:/vscode.git/clone" did not exist on "1ac07d8a8dd38fb155da48e5aafbfb63b3958520"
Commit
f6f38253
authored
Nov 07, 2019
by
Lysandre
Committed by
Lysandre Debut
Nov 26, 2019
Browse files
ALBERT in TF2
parent
d9daad98
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
736 additions
and
6 deletions
+736
-6
transformers/__init__.py
transformers/__init__.py
+4
-2
transformers/convert_pytorch_checkpoint_to_tf2.py
transformers/convert_pytorch_checkpoint_to_tf2.py
+9
-4
transformers/modeling_tf_albert.py
transformers/modeling_tf_albert.py
+723
-0
No files found.
transformers/__init__.py
View file @
f6f38253
...
...
@@ -58,8 +58,7 @@ from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_xlm
import
XLMConfig
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_roberta
import
RobertaConfig
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_distilbert
import
DistilBertConfig
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_albert
import
AlbertConfig
,
ALBERT
from
.configuration_albert
import
AlbertConfig
from
.configuration_albert
import
AlbertConfig
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.configuration_camembert
import
CamembertConfig
,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
# Modeling
...
...
@@ -169,6 +168,9 @@ if is_tf_available():
TFCTRLLMHeadModel
,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_tf_albert
import
(
TFAlbertPreTrainedModel
,
TFAlbertModel
,
TFAlbertForMaskedLM
,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# TF 2.0 <=> PyTorch conversion utilities
from
.modeling_tf_pytorch_utils
import
(
convert_tf_weight_name_to_pt_weight_name
,
load_pytorch_checkpoint_in_tf2_model
,
...
...
transformers/convert_pytorch_checkpoint_to_tf2.py
View file @
f6f38253
...
...
@@ -33,7 +33,8 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
OpenAIGPTConfig
,
TFOpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
RobertaConfig
,
TFRobertaForMaskedLM
,
TFRobertaForSequenceClassification
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DistilBertConfig
,
TFDistilBertForMaskedLM
,
TFDistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
,
AlbertConfig
,
TFAlbertForMaskedLM
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
if
is_torch_available
():
import
torch
...
...
@@ -46,7 +47,8 @@ if is_torch_available():
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
(
BertForPreTraining
,
BertForQuestionAnswering
,
BertForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
GPT2LMHeadModel
,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
...
...
@@ -56,7 +58,8 @@ else:
OpenAIGPTLMHeadModel
,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
,
RobertaForMaskedLM
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
DistilBertForMaskedLM
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
=
(
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
...
...
@@ -65,6 +68,7 @@ else:
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
)
...
...
@@ -85,7 +89,8 @@ MODEL_CLASSES = {
'roberta-large-mnli'
:
(
RobertaConfig
,
TFRobertaForSequenceClassification
,
RobertaForSequenceClassification
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert'
:
(
DistilBertConfig
,
TFDistilBertForMaskedLM
,
DistilBertForMaskedLM
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'distilbert-base-uncased-distilled-squad'
:
(
DistilBertConfig
,
TFDistilBertForQuestionAnswering
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
)
'ctrl'
:
(
CTRLConfig
,
TFCTRLLMHeadModel
,
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
),
'albert'
:
(
AlbertConfig
,
TFAlbertForMaskedLM
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
)
}
def
convert_pt_checkpoint_to_tf
(
model_type
,
pytorch_checkpoint_path
,
config_file
,
tf_dump_path
,
compare_with_pt_model
=
False
,
use_cached_models
=
True
):
...
...
transformers/modeling_tf_albert.py
0 → 100644
View file @
f6f38253
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment