Unverified Commit f75bf05c authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2352 from huggingface/cli_tweaks

Cli tweaks
parents bfe870be 0d467fd6
......@@ -499,7 +499,7 @@ model = AutoModel.from_pretrained("username/pretrained_model")
Finally, list all your files on S3:
```shell
transformers-cli ls
transformers-cli s3 ls
# List all your S3 objects.
```
......
......@@ -34,7 +34,7 @@ model = AutoModel.from_pretrained("username/pretrained_model")
Finally, list all your files on S3:
```shell
transformers-cli ls
transformers-cli s3 ls
# List all your S3 objects.
```
# coding: utf8
def main():
    """
    Entry point for the `transformers` command line interface.

    The sub-command is read from `sys.argv[1]`:

    - `convert` and `train` are imported lazily from `transformers.commands` and called with the
      full `sys.argv` list.
    - `serve` is currently a no-op (its intended implementation is kept commented out below).
    - `predict` is accepted as a valid name but has no handler here.

    When the sub-command is missing or not one of the names above, a usage message is printed and
    the function returns without dispatching anything.
    """
    import sys

    if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
        print(
            "First argument to `transformers` command line interface should be one of: \n"
            ">> convert serve train predict"
        )
        # Stop here: falling through would index `sys.argv[1]`, which raises IndexError when no
        # sub-command was given at all.
        return

    command = sys.argv[1]
    if command == "convert":
        from transformers.commands import convert

        convert(sys.argv)
    elif command == "train":
        from transformers.commands import train

        train(sys.argv)
    elif command == "serve":
        pass
        # from argparse import ArgumentParser
        # from transformers.commands.serving import ServeCommand
        # parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
        # commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
        # # Register commands
        # ServeCommand.register_subcommand(commands_parser)
        # # Let's go
        # args = parser.parse_args()
        # if not hasattr(args, 'func'):
        #     parser.print_help()
        #     exit(1)
        # # Run
        # service = args.func(args)
        # service.run()


if __name__ == "__main__":
    main()
......@@ -25,7 +25,7 @@ class ConvertCommand(BaseTransformersCLICommand):
train_parser = parser.add_parser(
"convert",
help="CLI tool to run convert model from original "
"author checkpoints to Transformesr PyTorch checkpoints.",
"author checkpoints to Transformers PyTorch checkpoints.",
)
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
train_parser.add_argument(
......
......@@ -9,17 +9,26 @@ from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder
UPLOAD_MAX_FILES = 15
class UserCommands(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
login_parser = parser.add_parser("login")
login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
login_parser.set_defaults(func=lambda args: LoginCommand(args))
whoami_parser = parser.add_parser("whoami")
whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser("logout")
logout_parser = parser.add_parser("logout", help="Log out")
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
list_parser = parser.add_parser("ls")
list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
# s3
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls")
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
rm_parser = s3_subparsers.add_parser("rm")
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload
upload_parser = parser.add_parser("upload")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
......@@ -131,13 +140,27 @@ class ListObjsCommand(BaseUserCommand):
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand):
    """
    CLI command that deletes a single stored object, identified by `self.args.filename`.
    """

    def run(self):
        """
        Delete the requested object through `self._api.delete_obj`.

        Prints "Not logged in" and exits with status 1 when `HfFolder.get_token()` returns None.
        If the API call raises `HTTPError`, the error is printed and the process exits with
        status 1. On success, prints "Done".
        """
        saved_token = HfFolder.get_token()
        if saved_token is None:
            print("Not logged in")
            exit(1)

        try:
            self._api.delete_obj(saved_token, filename=self.args.filename)
        except HTTPError as http_err:
            print(http_err)
            exit(1)
        else:
            print("Done")
class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path):
"""
Recursively list all files in a folder.
"""
entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
for f in entries:
if f.is_dir():
files += self.walk_dir(f.path)
......@@ -160,6 +183,14 @@ class UploadCommand(BaseUserCommand):
else:
raise ValueError("Not a valid file or directory: {}".format(local_path))
if len(files) > UPLOAD_MAX_FILES:
print(
"About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
ANSI.bold(len(files))
)
)
exit(1)
for filepath, filename in files:
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
......
......@@ -79,7 +79,7 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
def presign(self, token: str, filename) -> PresignedUrl:
def presign(self, token: str, filename: str) -> PresignedUrl:
"""
Call HF API to get a presigned url to upload `filename` to S3.
"""
......@@ -89,7 +89,7 @@ class HfApi:
d = r.json()
return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename, filepath) -> str:
def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
"""
Get a presigned url, then upload file to S3.
......@@ -111,7 +111,7 @@ class HfApi:
pf.close()
return urls.access
def list_objs(self, token) -> List[S3Obj]:
def list_objs(self, token: str) -> List[S3Obj]:
"""
Call HF API to list all stored files for user.
"""
......@@ -121,6 +121,14 @@ class HfApi:
d = r.json()
return [S3Obj(**x) for x in d]
def delete_obj(self, token: str, filename: str):
    """
    Call HF API to delete a file stored by user.

    Sends an HTTP DELETE to `<endpoint>/api/deleteObj`, passing `token` as a bearer
    authorization header and `filename` in the JSON body. Any error raised by
    `raise_for_status()` on the response propagates to the caller.
    """
    url = "{}/api/deleteObj".format(self.endpoint)
    auth_headers = {"authorization": "Bearer {}".format(token)}
    response = requests.delete(url, headers=auth_headers, json={"filename": filename})
    response.raise_for_status()
class TqdmProgressFileReader:
"""
......
......@@ -56,7 +56,7 @@ You can then finish the addition step by adding imports for your classes in the
- [ ] add your PyTorch and TF 2.0 model respectively in `modeling_auto.py` and `modeling_tf_auto.py`
- [ ] add your tokenizer in `tokenization_auto.py`
- [ ] add your models and tokenizer to `pipeline.py`
- [ ] add a link to your conversion script in the main conversion utility (currently in `__main__` but will be moved to the `commands` subfolder in the near future)
- [ ] add a link to your conversion script in the main conversion utility (in `commands/convert.py`)
- [ ] edit the PyTorch to TF 2.0 conversion script to add your model in the `convert_pytorch_checkpoint_to_tf2.py` file
- [ ] add a mention of your model in the doc: `README.md` and the documentation itself at `docs/source/pretrained_models.rst`.
- [ ] upload the pretrained weights, configurations and vocabulary files.
......@@ -60,6 +60,11 @@ class HfApiEndpointsTest(HfApiCommonTest):
"""
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
    """
    Delete every file listed in `FILES` through the API, authenticating with `cls._token`.
    """
    for remote_key, _local_path in FILES:
        # Only the key is needed for deletion; the local path half of each pair is unused here.
        cls._api.delete_obj(token=cls._token, filename=remote_key)
def test_whoami(self):
user = self._api.whoami(token=self._token)
self.assertEqual(user, USER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment