"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b191d7db442fcd4ea794768363155c091e60b868"
Unverified Commit 8940c766 authored by acul3, committed by GitHub

Add t5 convert to transformers-cli (#9654)

* Update run_mlm.py

* add t5 model to transformers-cli convert

* update run_mlm.py same as master

* update converting model docs

* update converting model docs

* Update convert.py

* Trigger notification

* update import sorted

* fix typo t5
parent 7251a473
...@@ -168,3 +168,18 @@ Here is an example of the conversion process for a pre-trained XLM model:
      --pytorch_dump_output $PYTORCH_DUMP_OUTPUT
     [--config XML_CONFIG] \
     [--finetuning_task_name XML_FINETUNED_TASK]

T5
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Here is an example of the conversion process for a pre-trained T5 model:

.. code-block:: shell

    export T5=/path/to/t5/uncased_L-12_H-768_A-12

    transformers-cli convert --model_type t5 \
      --tf_checkpoint $T5/t5_model.ckpt \
      --config $T5/t5_config.json \
      --pytorch_dump_output $T5/pytorch_model.bin
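
For scripted conversions, the same function that the CLI branch below wires up can also be called directly from Python. The following is a minimal sketch, not part of this PR: it assumes the converter is importable as ``transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch`` (the absolute counterpart of the relative import shown in ``convert.py`` below) and that the three paths point at a real TF checkpoint, its config file, and the desired output location.

.. code-block:: python

    import os

    # Assumed absolute module path, mirroring the relative import added in convert.py.
    from transformers.models.t5.convert_t5_original_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch,
    )

    # Hypothetical checkpoint directory, matching the shell example above.
    t5_dir = "/path/to/t5/uncased_L-12_H-768_A-12"

    # Same three arguments the CLI passes: TF checkpoint, config file, output path.
    convert_tf_checkpoint_to_pytorch(
        os.path.join(t5_dir, "t5_model.ckpt"),
        os.path.join(t5_dir, "t5_config.json"),
        os.path.join(t5_dir, "pytorch_model.bin"),
    )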
...@@ -110,6 +110,13 @@ class ConvertCommand(BaseTransformersCLICommand):
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "t5":
            try:
                from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
            except ImportError:
                raise ImportError(IMPORT_ERROR_MESSAGE)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
...@@ -168,5 +175,5 @@ class ConvertCommand(BaseTransformersCLICommand):
            convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
            raise ValueError(
                "--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]"
            )
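
A quick way to sanity-check the converted weights is to load them back into ``T5ForConditionalGeneration``. The sketch below is an assumption-laden illustration, not part of the PR: it presumes the converter serializes the model with ``save_pretrained`` at the ``--pytorch_dump_output`` path, in which case ``from_pretrained`` can read it back; if your version writes a raw state dict instead, the commented-out branch applies.

.. code-block:: python

    from transformers import T5ForConditionalGeneration

    # Hypothetical output location, matching the shell example above.
    dump = "/path/to/t5/uncased_L-12_H-768_A-12/pytorch_model.bin"

    # If the converter saved a ``save_pretrained``-style checkpoint at ``dump``,
    # ``from_pretrained`` can load it back directly:
    model = T5ForConditionalGeneration.from_pretrained(dump)
    model.eval()

    # If your version of the converter writes a plain PyTorch state dict instead,
    # rebuild the model from the original config and load the weights manually:
    # import torch
    # from transformers import T5Config
    # config = T5Config.from_json_file("/path/to/t5/uncased_L-12_H-768_A-12/t5_config.json")
    # model = T5ForConditionalGeneration(config)
    # model.load_state_dict(torch.load(dump, map_location="cpu"))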