Commit f78ebc22 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[cli] Add ability to delete remote object

parent bfe870be
...@@ -25,7 +25,7 @@ class ConvertCommand(BaseTransformersCLICommand): ...@@ -25,7 +25,7 @@ class ConvertCommand(BaseTransformersCLICommand):
train_parser = parser.add_parser( train_parser = parser.add_parser(
"convert", "convert",
help="CLI tool to run convert model from original " help="CLI tool to run convert model from original "
"author checkpoints to Transformesr PyTorch checkpoints.", "author checkpoints to Transformers PyTorch checkpoints.",
) )
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.") train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
train_parser.add_argument( train_parser.add_argument(
......
...@@ -12,14 +12,20 @@ from transformers.hf_api import HfApi, HfFolder ...@@ -12,14 +12,20 @@ from transformers.hf_api import HfApi, HfFolder
class UserCommands(BaseTransformersCLICommand): class UserCommands(BaseTransformersCLICommand):
@staticmethod @staticmethod
def register_subcommand(parser: ArgumentParser): def register_subcommand(parser: ArgumentParser):
login_parser = parser.add_parser("login") login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
login_parser.set_defaults(func=lambda args: LoginCommand(args)) login_parser.set_defaults(func=lambda args: LoginCommand(args))
whoami_parser = parser.add_parser("whoami") whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser("logout") logout_parser = parser.add_parser("logout", help="Log out")
logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
list_parser = parser.add_parser("ls") # s3
list_parser.set_defaults(func=lambda args: ListObjsCommand(args)) s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls")
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
rm_parser = s3_subparsers.add_parser("rm")
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload # upload
upload_parser = parser.add_parser("upload") upload_parser = parser.add_parser("upload")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.") upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
...@@ -131,13 +137,27 @@ class ListObjsCommand(BaseUserCommand): ...@@ -131,13 +137,27 @@ class ListObjsCommand(BaseUserCommand):
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])) print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
self._api.delete_obj(token, filename=self.args.filename)
except HTTPError as a:
print(e)
exit(1)
print("Done")
class UploadCommand(BaseUserCommand): class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path): def walk_dir(self, rel_path):
""" """
Recursively list all files in a folder. Recursively list all files in a folder.
""" """
entries: List[os.DirEntry] = list(os.scandir(rel_path)) entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # filepath # filename files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
for f in entries: for f in entries:
if f.is_dir(): if f.is_dir():
files += self.walk_dir(f.path) files += self.walk_dir(f.path)
......
...@@ -79,7 +79,7 @@ class HfApi: ...@@ -79,7 +79,7 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
def presign(self, token: str, filename) -> PresignedUrl: def presign(self, token: str, filename: str) -> PresignedUrl:
""" """
Call HF API to get a presigned url to upload `filename` to S3. Call HF API to get a presigned url to upload `filename` to S3.
""" """
...@@ -89,7 +89,7 @@ class HfApi: ...@@ -89,7 +89,7 @@ class HfApi:
d = r.json() d = r.json()
return PresignedUrl(**d) return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename, filepath) -> str: def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
""" """
Get a presigned url, then upload file to S3. Get a presigned url, then upload file to S3.
...@@ -111,7 +111,7 @@ class HfApi: ...@@ -111,7 +111,7 @@ class HfApi:
pf.close() pf.close()
return urls.access return urls.access
def list_objs(self, token) -> List[S3Obj]: def list_objs(self, token: str) -> List[S3Obj]:
""" """
Call HF API to list all stored files for user. Call HF API to list all stored files for user.
""" """
...@@ -121,6 +121,14 @@ class HfApi: ...@@ -121,6 +121,14 @@ class HfApi:
d = r.json() d = r.json()
return [S3Obj(**x) for x in d] return [S3Obj(**x) for x in d]
def delete_obj(self, token: str, filename: str):
"""
Call HF API to delete a file stored by user
"""
path = "{}/api/deleteObj".format(self.endpoint)
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status()
class TqdmProgressFileReader: class TqdmProgressFileReader:
""" """
......
...@@ -60,6 +60,11 @@ class HfApiEndpointsTest(HfApiCommonTest): ...@@ -60,6 +60,11 @@ class HfApiEndpointsTest(HfApiCommonTest):
""" """
cls._token = cls._api.login(username=USER, password=PASS) cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
for FILE_KEY, FILE_PATH in FILES:
cls._api.delete_obj(token=cls._token, filename=FILE_KEY)
def test_whoami(self): def test_whoami(self):
user = self._api.whoami(token=self._token) user = self._api.whoami(token=self._token)
self.assertEqual(user, USER) self.assertEqual(user, USER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment