"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "7d5d0fd577fd6c4ad5de8d07a71aa7599c457b70"
Unverified Commit f75bf05c authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2352 from huggingface/cli_tweaks

Cli tweaks
parents bfe870be 0d467fd6
...@@ -499,7 +499,7 @@ model = AutoModel.from_pretrained("username/pretrained_model") ...@@ -499,7 +499,7 @@ model = AutoModel.from_pretrained("username/pretrained_model")
Finally, list all your files on S3: Finally, list all your files on S3:
```shell ```shell
transformers-cli ls transformers-cli s3 ls
# List all your S3 objects. # List all your S3 objects.
``` ```
......
...@@ -34,7 +34,7 @@ model = AutoModel.from_pretrained("username/pretrained_model") ...@@ -34,7 +34,7 @@ model = AutoModel.from_pretrained("username/pretrained_model")
Finally, list all your files on S3: Finally, list all your files on S3:
```shell ```shell
transformers-cli ls transformers-cli s3 ls
# List all your S3 objects. # List all your S3 objects.
``` ```
# coding: utf8
def main():
import sys
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
print(
"First argument to `transformers` command line interface should be one of: \n"
">> convert serve train predict"
)
if sys.argv[1] == "convert":
from transformers.commands import convert
convert(sys.argv)
elif sys.argv[1] == "train":
from transformers.commands import train
train(sys.argv)
elif sys.argv[1] == "serve":
pass
# from argparse import ArgumentParser
# from transformers.commands.serving import ServeCommand
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# # Register commands
# ServeCommand.register_subcommand(commands_parser)
# # Let's go
# args = parser.parse_args()
# if not hasattr(args, 'func'):
# parser.print_help()
# exit(1)
# # Run
# service = args.func(args)
# service.run()
if __name__ == "__main__":
main()
...@@ -25,7 +25,7 @@ class ConvertCommand(BaseTransformersCLICommand): ...@@ -25,7 +25,7 @@ class ConvertCommand(BaseTransformersCLICommand):
train_parser = parser.add_parser( train_parser = parser.add_parser(
"convert", "convert",
help="CLI tool to run convert model from original " help="CLI tool to run convert model from original "
"author checkpoints to Transformesr PyTorch checkpoints.", "author checkpoints to Transformers PyTorch checkpoints.",
) )
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.") train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
train_parser.add_argument( train_parser.add_argument(
......
...@@ -9,17 +9,26 @@ from transformers.commands import BaseTransformersCLICommand ...@@ -9,17 +9,26 @@ from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder from transformers.hf_api import HfApi, HfFolder
UPLOAD_MAX_FILES = 15
class UserCommands(BaseTransformersCLICommand): class UserCommands(BaseTransformersCLICommand):
@staticmethod @staticmethod
def register_subcommand(parser: ArgumentParser): def register_subcommand(parser: ArgumentParser):
login_parser = parser.add_parser("login") login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
login_parser.set_defaults(func=lambda args: LoginCommand(args)) login_parser.set_defaults(func=lambda args: LoginCommand(args))
whoami_parser = parser.add_parser("whoami") whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser("logout") logout_parser = parser.add_parser("logout", help="Log out")
logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
list_parser = parser.add_parser("ls") # s3
list_parser.set_defaults(func=lambda args: ListObjsCommand(args)) s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls")
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
rm_parser = s3_subparsers.add_parser("rm")
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload # upload
upload_parser = parser.add_parser("upload") upload_parser = parser.add_parser("upload")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.") upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
...@@ -131,13 +140,27 @@ class ListObjsCommand(BaseUserCommand): ...@@ -131,13 +140,27 @@ class ListObjsCommand(BaseUserCommand):
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])) print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
self._api.delete_obj(token, filename=self.args.filename)
except HTTPError as e:
print(e)
exit(1)
print("Done")
class UploadCommand(BaseUserCommand): class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path): def walk_dir(self, rel_path):
""" """
Recursively list all files in a folder. Recursively list all files in a folder.
""" """
entries: List[os.DirEntry] = list(os.scandir(rel_path)) entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
for f in entries: for f in entries:
if f.is_dir(): if f.is_dir():
files += self.walk_dir(f.path) files += self.walk_dir(f.path)
...@@ -160,6 +183,14 @@ class UploadCommand(BaseUserCommand): ...@@ -160,6 +183,14 @@ class UploadCommand(BaseUserCommand):
else: else:
raise ValueError("Not a valid file or directory: {}".format(local_path)) raise ValueError("Not a valid file or directory: {}".format(local_path))
if len(files) > UPLOAD_MAX_FILES:
print(
"About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
ANSI.bold(len(files))
)
)
exit(1)
for filepath, filename in files: for filepath, filename in files:
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename))) print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))
......
...@@ -79,7 +79,7 @@ class HfApi: ...@@ -79,7 +79,7 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
def presign(self, token: str, filename) -> PresignedUrl: def presign(self, token: str, filename: str) -> PresignedUrl:
""" """
Call HF API to get a presigned url to upload `filename` to S3. Call HF API to get a presigned url to upload `filename` to S3.
""" """
...@@ -89,7 +89,7 @@ class HfApi: ...@@ -89,7 +89,7 @@ class HfApi:
d = r.json() d = r.json()
return PresignedUrl(**d) return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename, filepath) -> str: def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
""" """
Get a presigned url, then upload file to S3. Get a presigned url, then upload file to S3.
...@@ -111,7 +111,7 @@ class HfApi: ...@@ -111,7 +111,7 @@ class HfApi:
pf.close() pf.close()
return urls.access return urls.access
def list_objs(self, token) -> List[S3Obj]: def list_objs(self, token: str) -> List[S3Obj]:
""" """
Call HF API to list all stored files for user. Call HF API to list all stored files for user.
""" """
...@@ -121,6 +121,14 @@ class HfApi: ...@@ -121,6 +121,14 @@ class HfApi:
d = r.json() d = r.json()
return [S3Obj(**x) for x in d] return [S3Obj(**x) for x in d]
def delete_obj(self, token: str, filename: str):
"""
Call HF API to delete a file stored by user
"""
path = "{}/api/deleteObj".format(self.endpoint)
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status()
class TqdmProgressFileReader: class TqdmProgressFileReader:
""" """
......
...@@ -56,7 +56,7 @@ You can then finish the addition step by adding imports for your classes in the ...@@ -56,7 +56,7 @@ You can then finish the addition step by adding imports for your classes in the
- [ ] add your PyTorch and TF 2.0 model respectively in `modeling_auto.py` and `modeling_tf_auto.py` - [ ] add your PyTorch and TF 2.0 model respectively in `modeling_auto.py` and `modeling_tf_auto.py`
- [ ] add your tokenizer in `tokenization_auto.py` - [ ] add your tokenizer in `tokenization_auto.py`
- [ ] add your models and tokenizer to `pipeline.py` - [ ] add your models and tokenizer to `pipeline.py`
- [ ] add a link to your conversion script in the main conversion utility (currently in `__main__` but will be moved to the `commands` subfolder in the near future) - [ ] add a link to your conversion script in the main conversion utility (in `commands/convert.py`)
- [ ] edit the PyTorch to TF 2.0 conversion script to add your model in the `convert_pytorch_checkpoint_to_tf2.py` file - [ ] edit the PyTorch to TF 2.0 conversion script to add your model in the `convert_pytorch_checkpoint_to_tf2.py` file
- [ ] add a mention of your model in the doc: `README.md` and the documentation itself at `docs/source/pretrained_models.rst`. - [ ] add a mention of your model in the doc: `README.md` and the documentation itself at `docs/source/pretrained_models.rst`.
- [ ] upload the pretrained weigths, configurations and vocabulary files. - [ ] upload the pretrained weigths, configurations and vocabulary files.
...@@ -60,6 +60,11 @@ class HfApiEndpointsTest(HfApiCommonTest): ...@@ -60,6 +60,11 @@ class HfApiEndpointsTest(HfApiCommonTest):
""" """
cls._token = cls._api.login(username=USER, password=PASS) cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
for FILE_KEY, FILE_PATH in FILES:
cls._api.delete_obj(token=cls._token, filename=FILE_KEY)
def test_whoami(self): def test_whoami(self):
user = self._api.whoami(token=self._token) user = self._api.whoami(token=self._token)
self.assertEqual(user, USER) self.assertEqual(user, USER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment