Commit 81babb22 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Added download command through the cli.

It allows to predownload models and tokenizers.
parent 31a3a73e
#!/usr/bin/env python
from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand
......@@ -11,10 +12,11 @@ if __name__ == '__main__':
commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser)
ConvertCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
......
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
def download_command_factory(args):
return DownloadCommand(args.model, args.cache_dir, args.force)
class DownloadCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser('download')
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir')
download_parser.add_argument('model', type=str, help='Name of the model to download')
download_parser.set_defaults(func=download_command_factory)
def __init__(self, model: str, cache: str, force: bool):
self._model = model
self._cache = cache
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment