Commit 96fa9a8a authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Python 2 + Post mime-type to S3

parent e4fbf3e2
......@@ -14,9 +14,9 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function
from typing import List, NamedTuple
import os
from os.path import expanduser
import six
import requests
from requests.exceptions import HTTPError
......@@ -24,23 +24,43 @@ from requests.exceptions import HTTPError
ENDPOINT = "https://huggingface.co"
class S3Obj:
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int):
def __init__(
self,
filename, # type: str
LastModified, # type: str
ETag, # type: str
Size, # type: int
**kwargs
):
self.filename = filename
self.LastModified = LastModified
self.ETag = ETag
self.Size = Size
class PresignedUrl(NamedTuple):
write: str
access: str
class PresignedUrl:
def __init__(
self,
write, # type: str
access, # type: str
type, # type: str
**kwargs
):
self.write = write
self.access = access
self.type = type # mime-type to send to S3.
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str:
def login(
self,
username, # type: str
password, # type: str
):
# type: (...) -> str
"""
Call HF API to sign in a user and get a token if credentials are valid.
......@@ -56,7 +76,11 @@ class HfApi:
d = r.json()
return d["token"]
def whoami(self, token: str) -> str:
def whoami(
self,
token, # type: str
):
# type: (...) -> str
"""
Call HF API to know "whoami"
"""
......@@ -66,7 +90,8 @@ class HfApi:
d = r.json()
return d["user"]
def logout(self, token: str):
def logout(self, token):
# type: (...) -> void
"""
Call HF API to log out.
"""
......@@ -74,7 +99,8 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
def presign(self, token: str, filename: str) -> PresignedUrl:
def presign(self, token, filename):
# type: (...) -> PresignedUrl
"""
Call HF API to get a presigned url to upload `filename` to S3.
"""
......@@ -88,7 +114,8 @@ class HfApi:
d = r.json()
return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename: str, filepath: str) -> str:
def presign_and_upload(self, token, filename, filepath):
# type: (...) -> str
"""
Get a presigned url, then upload file to S3.
......@@ -98,12 +125,18 @@ class HfApi:
urls = self.presign(token, filename=filename)
# streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
#
# Even though we presign with the correct content-type,
# the client still has to specify it when uploading the file.
with open(filepath, "rb") as f:
r = requests.put(urls.write, data=f)
r = requests.put(urls.write, data=f, headers={
"content-type": urls.type,
})
r.raise_for_status()
return urls.access
def list_objs(self, token: str) -> List[S3Obj]:
def list_objs(self, token):
# type: (...) -> List[S3Obj]
"""
Call HF API to list all stored files for user.
"""
......@@ -121,11 +154,20 @@ class HfFolder:
path_token = expanduser("~/.huggingface/token")
@classmethod
def save_token(cls, token: str):
def save_token(cls, token):
"""
Save token, creating folder as needed.
"""
if six.PY3:
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
else:
# Python 2
try:
os.makedirs(os.path.dirname(cls.path_token))
except OSError as e:
if e.errno != os.errno.EEXIST:
raise e
pass
with open(cls.path_token, 'w+') as f:
f.write(token)
......@@ -137,7 +179,9 @@ class HfFolder:
try:
with open(cls.path_token, 'r') as f:
return f.read()
except FileNotFoundError:
except:
# this is too wide. When Py2 is dead use:
# `except FileNotFoundError:` instead
return None
@classmethod
......
......@@ -15,6 +15,7 @@
from __future__ import absolute_import, division, print_function
import os
import six
import time
import unittest
......@@ -40,7 +41,7 @@ class HfApiLoginTest(HfApiCommonTest):
def test_login_valid(self):
token = self._api.login(username=USER, password=PASS)
self.assertIsInstance(token, str)
self.assertIsInstance(token, six.string_types)
class HfApiEndpointsTest(HfApiCommonTest):
......@@ -56,17 +57,20 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertEqual(user, USER)
def test_presign(self):
url = self._api.presign(token=self._token, filename=FILE_KEY)
self.assertIsInstance(url, PresignedUrl)
urls = self._api.presign(token=self._token, filename=FILE_KEY)
self.assertIsInstance(urls, PresignedUrl)
self.assertEqual(urls.type, "text/plain")
def test_presign_and_upload(self):
access_url = self._api.presign_and_upload(
token=self._token, filename=FILE_KEY, filepath=FILE_PATH
)
self.assertIsInstance(access_url, str)
self.assertIsInstance(access_url, six.string_types)
def test_list_objs(self):
objs = self._api.list_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1]
self.assertIsInstance(o, S3Obj)
......@@ -92,3 +96,7 @@ class HfFolderTest(unittest.TestCase):
HfFolder.get_token(),
None
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment