.. 
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Converting TensorFlow Checkpoints
=======================================================================================================================

A command-line interface is provided to convert original BERT/GPT/GPT-2/Transformer-XL/XLNet/XLM checkpoints to models
that can be loaded using the ``from_pretrained`` methods of the library.

.. note::
    Since version 2.3.0 the conversion script is part of the transformers CLI (**transformers-cli**), available in any
    transformers >= 2.3.0 installation.

    The documentation below reflects the **transformers-cli convert** command format.

BERT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google
<https://github.com/google-research/bert#pre-trained-models>`_) into a PyTorch save file by using the
:prefix_link:`convert_bert_original_tf_checkpoint_to_pytorch.py
<src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py>` script.

This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``) and the associated
configuration file (``bert_config.json``), creates a PyTorch model for this configuration, loads the weights from the
TensorFlow checkpoint into the PyTorch model, and saves the resulting model in a standard PyTorch save file that can be
imported using ``from_pretrained()`` (see the example in :doc:`quicktour` and :prefix_link:`run_glue.py
<examples/pytorch/text-classification/run_glue.py>`).

You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow
checkpoint (the three files starting with ``bert_model.ckpt``), but be sure to keep the configuration file
(``bert_config.json``) and the vocabulary file (``vocab.txt``), as these are needed for the PyTorch model too.

To run this specific conversion script you will need to have TensorFlow and PyTorch installed (``pip install
tensorflow``). The rest of the repository only requires PyTorch.

Here is an example of the conversion process for a pre-trained ``BERT-Base Uncased`` model:

.. code-block:: shell

    export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

    transformers-cli convert --model_type bert \
      --tf_checkpoint $BERT_BASE_DIR/bert_model.ckpt \
      --config $BERT_BASE_DIR/bert_config.json \
      --pytorch_dump_output $BERT_BASE_DIR/pytorch_model.bin

You can download Google's pre-trained models for the conversion `here
<https://github.com/google-research/bert#pre-trained-models>`__.
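Once converted, the checkpoint can be loaded back with ``from_pretrained()``. The following is a minimal sketch, not
part of the conversion script: it assumes ``transformers`` and PyTorch are installed, that the conversion above has
been run, and that ``bert_config.json`` has been renamed to ``config.json`` (the file name ``from_pretrained()`` looks
for); the directory path is a placeholder.

```python
# Sketch: load a converted BERT checkpoint. Assumes the directory contains
# pytorch_model.bin, config.json (the renamed bert_config.json) and vocab.txt.
from transformers import BertForPreTraining, BertTokenizer


def load_converted(model_dir: str):
    """Load the PyTorch model and tokenizer produced by the conversion above."""
    model = BertForPreTraining.from_pretrained(model_dir)
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    return model, tokenizer


# Example (placeholder path):
# model, tokenizer = load_converted("/path/to/bert/uncased_L-12_H-768_A-12")
```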

ALBERT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Convert TensorFlow model checkpoints of ALBERT to PyTorch using the
:prefix_link:`convert_albert_original_tf_checkpoint_to_pytorch.py
<src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py>` script.

The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``) and the accompanying
configuration file (``albert_config.json``), then creates and saves a PyTorch model. To run this conversion you will
need to have TensorFlow and PyTorch installed.

Here is an example of the conversion process for the pre-trained ``ALBERT Base`` model:

.. code-block:: shell

    export ALBERT_BASE_DIR=/path/to/albert/albert_base

    transformers-cli convert --model_type albert \
      --tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \
      --config $ALBERT_BASE_DIR/albert_config.json \
      --pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin

You can download Google's pre-trained models for the conversion `here
<https://github.com/google-research/albert#pre-trained-models>`__.

OpenAI GPT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained OpenAI GPT model, assuming that your NumPy checkpoint
was saved in the same format as the OpenAI pretrained model (see `here
<https://github.com/openai/finetune-transformer-lm>`__):

.. code-block:: shell

    export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights

    transformers-cli convert --model_type gpt \
      --tf_checkpoint $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
      [--config OPENAI_GPT_CONFIG] \
      [--finetuning_task_name OPENAI_GPT_FINETUNED_TASK]


OpenAI GPT-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained OpenAI GPT-2 model (see `here
<https://github.com/openai/gpt-2>`__):

.. code-block:: shell

    export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/gpt2/pretrained/weights

    transformers-cli convert --model_type gpt2 \
      --tf_checkpoint $OPENAI_GPT2_CHECKPOINT_PATH \
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
      [--config OPENAI_GPT2_CONFIG] \
      [--finetuning_task_name OPENAI_GPT2_FINETUNED_TASK]

Transformer-XL
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained Transformer-XL model (see `here
<https://github.com/kimiyoung/transformer-xl/tree/master/tf#obtain-and-evaluate-pretrained-sota-models>`__):

.. code-block:: shell

    export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint

    transformers-cli convert --model_type transfo_xl \
      --tf_checkpoint $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
      [--config TRANSFO_XL_CONFIG] \
      [--finetuning_task_name TRANSFO_XL_FINETUNED_TASK]


XLNet
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained XLNet model:

.. code-block:: shell

    export XLNET_CHECKPOINT_PATH=/path/to/xlnet/checkpoint
    export XLNET_CONFIG_PATH=/path/to/xlnet/config

    transformers-cli convert --model_type xlnet \
      --tf_checkpoint $XLNET_CHECKPOINT_PATH \
      --config $XLNET_CONFIG_PATH \
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
      [--finetuning_task_name XLNET_FINETUNED_TASK]


XLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained XLM model:

.. code-block:: shell

    export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint

    transformers-cli convert --model_type xlm \
      --tf_checkpoint $XLM_CHECKPOINT_PATH \
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT \
      [--config XLM_CONFIG] \
      [--finetuning_task_name XLM_FINETUNED_TASK]


T5
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained T5 model:

.. code-block:: shell

    export T5=/path/to/t5/uncased_L-12_H-768_A-12

    transformers-cli convert --model_type t5 \
      --tf_checkpoint $T5/t5_model.ckpt \
      --config $T5/t5_config.json \
      --pytorch_dump_output $T5/pytorch_model.bin