Unverified Commit 41e82912 authored by fgaim's avatar fgaim Committed by GitHub
Browse files

Add ALBERT to the Tensorflow to Pytorch model conversion cli (#3933)

* Add ALBERT to convert command of transformers-cli

* Document ALBERT tf to pytorch model conversion
parent 3f42eb97
...@@ -12,7 +12,7 @@ A command-line interface is provided to convert original Bert/GPT/GPT-2/Transfor ...@@ -12,7 +12,7 @@ A command-line interface is provided to convert original Bert/GPT/GPT-2/Transfor
BERT BERT
^^^^ ^^^^
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/transformers/convert_tf_checkpoint_to_pytorch.py>`_ script. You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_bert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_extract_features.py>`_\ , `run_bert_classifier.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_classifier.py>`_ and `run_bert_squad.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_squad.py>`_\ ). This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_extract_features.py>`_\ , `run_bert_classifier.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_classifier.py>`_ and `run_bert_squad.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_squad.py>`_\ ).
...@@ -33,6 +33,26 @@ Here is an example of the conversion process for a pre-trained ``BERT-Base Uncas ...@@ -33,6 +33,26 @@ Here is an example of the conversion process for a pre-trained ``BERT-Base Uncas
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/bert#pre-trained-models>`__. You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/bert#pre-trained-models>`__.
ALBERT
^^^^^^
Convert TensorFlow model checkpoints of ALBERT to PyTorch using the `convert_albert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``\ ) and the accompanying configuration file (\ ``albert_config.json``\ ), then creates and saves a PyTorch model. To run this conversion you will need to have TensorFlow and PyTorch installed.
Here is an example of the conversion process for the pre-trained ``ALBERT Base`` model:
.. code-block:: shell
export ALBERT_BASE_DIR=/path/to/albert/albert_base
transformers-cli convert --model_type albert \
--tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \
--config $ALBERT_BASE_DIR/albert_config.json \
--pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/albert#pre-trained-models>`__.
OpenAI GPT OpenAI GPT
^^^^^^^^^^ ^^^^^^^^^^
......
...@@ -62,7 +62,21 @@ class ConvertCommand(BaseTransformersCLICommand): ...@@ -62,7 +62,21 @@ class ConvertCommand(BaseTransformersCLICommand):
self._finetuning_task_name = finetuning_task_name self._finetuning_task_name = finetuning_task_name
def run(self): def run(self):
if self._model_type == "bert": if self._model_type == "albert":
try:
from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
msg = (
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "bert":
try: try:
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import ( from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch, convert_tf_checkpoint_to_pytorch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment