Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c9875455
Commit
c9875455
authored
Oct 31, 2019
by
Lysandre
Committed by
Lysandre Debut
Nov 26, 2019
Browse files
Converting script
parent
4f3a54bf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
8 deletions
+30
-8
transformers/__init__.py
transformers/__init__.py
+1
-1
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py
...rmers/convert_albert_original_tf_checkpoint_to_pytorch.py
+29
-7
No files found.
transformers/__init__.py
View file @
c9875455
...
@@ -107,7 +107,7 @@ if is_torch_available():
...
@@ -107,7 +107,7 @@ if is_torch_available():
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_encoder_decoder
import
PreTrainedEncoderDecoder
,
Model2Model
from
.modeling_encoder_decoder
import
PreTrainedEncoderDecoder
,
Model2Model
from
.modeling_albert
import
(
AlbertModel
,
AlbertForMaskedLM
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_albert
import
(
AlbertModel
,
AlbertForMaskedLM
,
load_tf_weights_in_albert
,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
# Optimization
# Optimization
from
.optimization
import
(
AdamW
,
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
from
.optimization
import
(
AdamW
,
get_constant_schedule
,
get_constant_schedule_with_warmup
,
get_cosine_schedule_with_warmup
,
...
...
transformers/convert_albert_original_tf_checkpoint_to_pytorch.py
View file @
c9875455
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
torch
from
transformers
import
AlbertConfig
,
BertForPreTraining
,
load_tf_weights_in_bert
from
transformers
import
AlbertConfig
,
AlbertForMaskedLM
,
load_tf_weights_in_albert
import
logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
def
convert_tf_checkpoint_to_pytorch
(
tf_checkpoint_path
,
bert_config_file
,
pytorch_dump_path
):
# Initialise PyTorch model
# Initialise PyTorch model
config
=
B
ertConfig
.
from_json_file
(
bert_config_file
)
config
=
Alb
ertConfig
.
from_json_file
(
bert_config_file
)
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
print
(
"Building PyTorch model from configuration: {}"
.
format
(
str
(
config
)))
model
=
B
ertFor
PreTraining
(
config
)
model
=
Alb
ertFor
MaskedLM
(
config
)
# Load weights from tf checkpoint
# Load weights from tf checkpoint
load_tf_weights_in_bert
(
model
,
config
,
tf_checkpoint_path
)
load_tf_weights_in_
al
bert
(
model
,
config
,
tf_checkpoint_path
)
# Save pytorch-model
# Save pytorch-model
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
print
(
"Save PyTorch model to {}"
.
format
(
pytorch_dump_path
))
...
@@ -31,7 +52,7 @@ if __name__ == "__main__":
...
@@ -31,7 +52,7 @@ if __name__ == "__main__":
default
=
None
,
default
=
None
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
help
=
"The config json file corresponding to the pre-trained BERT model.
\n
"
help
=
"The config json file corresponding to the pre-trained
AL
BERT model.
\n
"
"This specifies the model architecture."
)
"This specifies the model architecture."
)
parser
.
add_argument
(
"--pytorch_dump_path"
,
parser
.
add_argument
(
"--pytorch_dump_path"
,
default
=
None
,
default
=
None
,
...
@@ -40,5 +61,6 @@ if __name__ == "__main__":
...
@@ -40,5 +61,6 @@ if __name__ == "__main__":
help
=
"Path to the output PyTorch model."
)
help
=
"Path to the output PyTorch model."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
convert_tf_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
convert_tf_checkpoint_to_pytorch
(
args
.
tf_checkpoint_path
,
args
.
bert_config_file
,
args
.
al
bert_config_file
,
args
.
pytorch_dump_path
)
args
.
pytorch_dump_path
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment