Commit cbf8f5d3 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[model upload] Support for organizations

parent 525b6b1c
...@@ -26,13 +26,16 @@ class UserCommands(BaseTransformersCLICommand): ...@@ -26,13 +26,16 @@ class UserCommands(BaseTransformersCLICommand):
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.") s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands") s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls") ls_parser = s3_subparsers.add_parser("ls")
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
ls_parser.set_defaults(func=lambda args: ListObjsCommand(args)) ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
rm_parser = s3_subparsers.add_parser("rm") rm_parser = s3_subparsers.add_parser("rm")
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.") rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args)) rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload # upload
upload_parser = parser.add_parser("upload") upload_parser = parser.add_parser("upload")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.") upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
upload_parser.add_argument( upload_parser.add_argument(
"--filename", type=str, default=None, help="Optional: override individual object filename on S3." "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
) )
...@@ -91,8 +94,10 @@ class WhoamiCommand(BaseUserCommand): ...@@ -91,8 +94,10 @@ class WhoamiCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit() exit()
try: try:
user = self._api.whoami(token) user, orgs = self._api.whoami(token)
print(user) print(user)
if orgs:
print(ANSI.bold("orgs: "), ",".join(orgs))
except HTTPError as e: except HTTPError as e:
print(e) print(e)
...@@ -130,7 +135,7 @@ class ListObjsCommand(BaseUserCommand): ...@@ -130,7 +135,7 @@ class ListObjsCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit(1) exit(1)
try: try:
objs = self._api.list_objs(token) objs = self._api.list_objs(token, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
exit(1) exit(1)
...@@ -148,7 +153,7 @@ class DeleteObjCommand(BaseUserCommand): ...@@ -148,7 +153,7 @@ class DeleteObjCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit(1) exit(1)
try: try:
self._api.delete_obj(token, filename=self.args.filename) self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
exit(1) exit(1)
...@@ -195,8 +200,15 @@ class UploadCommand(BaseUserCommand): ...@@ -195,8 +200,15 @@ class UploadCommand(BaseUserCommand):
) )
exit(1) exit(1)
user, _ = self._api.whoami(token)
namespace = self.args.organization if self.args.organization is not None else user
for filepath, filename in files: for filepath, filename in files:
print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename))) print(
"About to upload file {} to S3 under filename {} and namespace {}".format(
ANSI.bold(filepath), ANSI.bold(filename), ANSI.bold(namespace)
)
)
choice = input("Proceed? [Y/n] ").lower() choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"): if not (choice == "" or choice == "y" or choice == "yes"):
...@@ -204,6 +216,8 @@ class UploadCommand(BaseUserCommand): ...@@ -204,6 +216,8 @@ class UploadCommand(BaseUserCommand):
exit() exit()
print(ANSI.bold("Uploading... This might take a while if files are large")) print(ANSI.bold("Uploading... This might take a while if files are large"))
for filepath, filename in files: for filepath, filename in files:
access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath) access_url = self._api.presign_and_upload(
token=token, filename=filename, filepath=filepath, organization=self.args.organization
)
print("Your file now lives at:") print("Your file now lives at:")
print(access_url) print(access_url)
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import io import io
import os import os
from os.path import expanduser from os.path import expanduser
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
import requests import requests
from tqdm import tqdm from tqdm import tqdm
...@@ -109,7 +109,7 @@ class HfApi: ...@@ -109,7 +109,7 @@ class HfApi:
d = r.json() d = r.json()
return d["token"] return d["token"]
def whoami(self, token: str) -> str: def whoami(self, token: str) -> Tuple[str, List[str]]:
""" """
Call HF API to know "whoami" Call HF API to know "whoami"
""" """
...@@ -117,7 +117,7 @@ class HfApi: ...@@ -117,7 +117,7 @@ class HfApi:
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
return d["user"] return d["user"], d["orgs"]
def logout(self, token: str) -> None: def logout(self, token: str) -> None:
""" """
...@@ -127,24 +127,28 @@ class HfApi: ...@@ -127,24 +127,28 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
def presign(self, token: str, filename: str) -> PresignedUrl: def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
""" """
Call HF API to get a presigned url to upload `filename` to S3. Call HF API to get a presigned url to upload `filename` to S3.
""" """
path = "{}/api/presign".format(self.endpoint) path = "{}/api/presign".format(self.endpoint)
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"filename": filename, "organization": organization},
)
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
return PresignedUrl(**d) return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename: str, filepath: str) -> str: def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
""" """
Get a presigned url, then upload file to S3. Get a presigned url, then upload file to S3.
Outputs: Outputs:
url: Read-only url for the stored file on S3. url: Read-only url for the stored file on S3.
""" """
urls = self.presign(token, filename=filename) urls = self.presign(token, filename=filename, organization=organization)
# streaming upload: # streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
# #
...@@ -159,22 +163,27 @@ class HfApi: ...@@ -159,22 +163,27 @@ class HfApi:
pf.close() pf.close()
return urls.access return urls.access
def list_objs(self, token: str) -> List[S3Obj]: def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
""" """
Call HF API to list all stored files for user. Call HF API to list all stored files for user (or one of their organizations).
""" """
path = "{}/api/listObjs".format(self.endpoint) path = "{}/api/listObjs".format(self.endpoint)
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
d = r.json() d = r.json()
return [S3Obj(**x) for x in d] return [S3Obj(**x) for x in d]
def delete_obj(self, token: str, filename: str): def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
""" """
Call HF API to delete a file stored by user Call HF API to delete a file stored by user
""" """
path = "{}/api/deleteObj".format(self.endpoint) path = "{}/api/deleteObj".format(self.endpoint)
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"filename": filename, "organization": organization},
)
r.raise_for_status() r.raise_for_status()
def model_list(self) -> List[ModelInfo]: def model_list(self) -> List[ModelInfo]:
......
...@@ -67,8 +67,17 @@ class HfApiEndpointsTest(HfApiCommonTest): ...@@ -67,8 +67,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
cls._api.delete_obj(token=cls._token, filename=FILE_KEY) cls._api.delete_obj(token=cls._token, filename=FILE_KEY)
def test_whoami(self): def test_whoami(self):
user = self._api.whoami(token=self._token) user, orgs = self._api.whoami(token=self._token)
self.assertEqual(user, USER) self.assertEqual(user, USER)
self.assertIsInstance(orgs, list)
def test_presign_invalid_org(self):
with self.assertRaises(HTTPError):
_ = self._api.presign(token=self._token, filename="fake_org.txt", organization="fake")
def test_presign_valid_org(self):
urls = self._api.presign(token=self._token, filename="valid_org.txt", organization="valid_org")
self.assertIsInstance(urls, PresignedUrl)
def test_presign(self): def test_presign(self):
for FILE_KEY, FILE_PATH in FILES: for FILE_KEY, FILE_PATH in FILES:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment