"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f71a4577b8a54d4a19fb1b1ca8ba15cfc6b5bb6e"
Commit 81babb22 authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Added download command through the cli.

It allows to predownload models and tokenizers.
parent 31a3a73e
#!/usr/bin/env python #!/usr/bin/env python
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.serving import ServeCommand from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand from transformers.commands.train import TrainCommand
...@@ -11,10 +12,11 @@ if __name__ == '__main__': ...@@ -11,10 +12,11 @@ if __name__ == '__main__':
commands_parser = parser.add_subparsers(help='transformers-cli command helpers') commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands # Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser) ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser) TrainCommand.register_subcommand(commands_parser)
ConvertCommand.register_subcommand(commands_parser)
# Let's go # Let's go
args = parser.parse_args() args = parser.parse_args()
......
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
def download_command_factory(args):
return DownloadCommand(args.model, args.cache_dir, args.force)
class DownloadCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser('download')
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir')
download_parser.add_argument('model', type=str, help='Name of the model to download')
download_parser.set_defaults(func=download_command_factory)
def __init__(self, model: str, cache: str, force: bool):
self._model = model
self._cache = cache
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment