"...resnet50_tensorflow.git" did not exist on "eb1498508a745e44e7ca72be9a2565d49708be43"
Commit 96fa9a8a authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Python 2 + Post mime-type to S3

parent e4fbf3e2
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from typing import List, NamedTuple
import os import os
from os.path import expanduser from os.path import expanduser
import six
import requests import requests
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
...@@ -24,23 +24,43 @@ from requests.exceptions import HTTPError ...@@ -24,23 +24,43 @@ from requests.exceptions import HTTPError
ENDPOINT = "https://huggingface.co" ENDPOINT = "https://huggingface.co"
class S3Obj: class S3Obj:
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int): def __init__(
self,
filename, # type: str
LastModified, # type: str
ETag, # type: str
Size, # type: int
**kwargs
):
self.filename = filename self.filename = filename
self.LastModified = LastModified self.LastModified = LastModified
self.ETag = ETag self.ETag = ETag
self.Size = Size self.Size = Size
class PresignedUrl(NamedTuple): class PresignedUrl:
write: str def __init__(
access: str self,
write, # type: str
access, # type: str
type, # type: str
**kwargs
):
self.write = write
self.access = access
self.type = type # mime-type to send to S3.
class HfApi: class HfApi:
def __init__(self, endpoint=None): def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str: def login(
self,
username, # type: str
password, # type: str
):
# type: (...) -> str
""" """
Call HF API to sign in a user and get a token if credentials are valid. Call HF API to sign in a user and get a token if credentials are valid.
...@@ -56,7 +76,11 @@ class HfApi: ...@@ -56,7 +76,11 @@ class HfApi:
d = r.json() d = r.json()
return d["token"] return d["token"]
def whoami(self, token: str) -> str: def whoami(
self,
token, # type: str
):
# type: (...) -> str
""" """
Call HF API to know "whoami" Call HF API to know "whoami"
""" """
...@@ -66,7 +90,8 @@ class HfApi: ...@@ -66,7 +90,8 @@ class HfApi:
d = r.json() d = r.json()
return d["user"] return d["user"]
def logout(self, token: str): def logout(self, token):
# type: (...) -> void
""" """
Call HF API to log out. Call HF API to log out.
""" """
...@@ -74,7 +99,8 @@ class HfApi: ...@@ -74,7 +99,8 @@ class HfApi:
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)}) r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status() r.raise_for_status()
def presign(self, token: str, filename: str) -> PresignedUrl: def presign(self, token, filename):
# type: (...) -> PresignedUrl
""" """
Call HF API to get a presigned url to upload `filename` to S3. Call HF API to get a presigned url to upload `filename` to S3.
""" """
...@@ -88,7 +114,8 @@ class HfApi: ...@@ -88,7 +114,8 @@ class HfApi:
d = r.json() d = r.json()
return PresignedUrl(**d) return PresignedUrl(**d)
def presign_and_upload(self, token: str, filename: str, filepath: str) -> str: def presign_and_upload(self, token, filename, filepath):
# type: (...) -> str
""" """
Get a presigned url, then upload file to S3. Get a presigned url, then upload file to S3.
...@@ -98,12 +125,18 @@ class HfApi: ...@@ -98,12 +125,18 @@ class HfApi:
urls = self.presign(token, filename=filename) urls = self.presign(token, filename=filename)
# streaming upload: # streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads # https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
#
# Even though we presign with the correct content-type,
# the client still has to specify it when uploading the file.
with open(filepath, "rb") as f: with open(filepath, "rb") as f:
r = requests.put(urls.write, data=f) r = requests.put(urls.write, data=f, headers={
"content-type": urls.type,
})
r.raise_for_status() r.raise_for_status()
return urls.access return urls.access
def list_objs(self, token: str) -> List[S3Obj]: def list_objs(self, token):
# type: (...) -> List[S3Obj]
""" """
Call HF API to list all stored files for user. Call HF API to list all stored files for user.
""" """
...@@ -121,11 +154,20 @@ class HfFolder: ...@@ -121,11 +154,20 @@ class HfFolder:
path_token = expanduser("~/.huggingface/token") path_token = expanduser("~/.huggingface/token")
@classmethod @classmethod
def save_token(cls, token: str): def save_token(cls, token):
""" """
Save token, creating folder as needed. Save token, creating folder as needed.
""" """
if six.PY3:
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True) os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
else:
# Python 2
try:
os.makedirs(os.path.dirname(cls.path_token))
except OSError as e:
if e.errno != os.errno.EEXIST:
raise e
pass
with open(cls.path_token, 'w+') as f: with open(cls.path_token, 'w+') as f:
f.write(token) f.write(token)
...@@ -137,7 +179,9 @@ class HfFolder: ...@@ -137,7 +179,9 @@ class HfFolder:
try: try:
with open(cls.path_token, 'r') as f: with open(cls.path_token, 'r') as f:
return f.read() return f.read()
except FileNotFoundError: except:
# this is too wide. When Py2 is dead use:
# `except FileNotFoundError:` instead
return None return None
@classmethod @classmethod
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os import os
import six
import time import time
import unittest import unittest
...@@ -40,7 +41,7 @@ class HfApiLoginTest(HfApiCommonTest): ...@@ -40,7 +41,7 @@ class HfApiLoginTest(HfApiCommonTest):
def test_login_valid(self): def test_login_valid(self):
token = self._api.login(username=USER, password=PASS) token = self._api.login(username=USER, password=PASS)
self.assertIsInstance(token, str) self.assertIsInstance(token, six.string_types)
class HfApiEndpointsTest(HfApiCommonTest): class HfApiEndpointsTest(HfApiCommonTest):
...@@ -56,17 +57,20 @@ class HfApiEndpointsTest(HfApiCommonTest): ...@@ -56,17 +57,20 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertEqual(user, USER) self.assertEqual(user, USER)
def test_presign(self): def test_presign(self):
url = self._api.presign(token=self._token, filename=FILE_KEY) urls = self._api.presign(token=self._token, filename=FILE_KEY)
self.assertIsInstance(url, PresignedUrl) self.assertIsInstance(urls, PresignedUrl)
self.assertEqual(urls.type, "text/plain")
def test_presign_and_upload(self): def test_presign_and_upload(self):
access_url = self._api.presign_and_upload( access_url = self._api.presign_and_upload(
token=self._token, filename=FILE_KEY, filepath=FILE_PATH token=self._token, filename=FILE_KEY, filepath=FILE_PATH
) )
self.assertIsInstance(access_url, str) self.assertIsInstance(access_url, six.string_types)
def test_list_objs(self): def test_list_objs(self):
objs = self._api.list_objs(token=self._token) objs = self._api.list_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1] o = objs[-1]
self.assertIsInstance(o, S3Obj) self.assertIsInstance(o, S3Obj)
...@@ -92,3 +96,7 @@ class HfFolderTest(unittest.TestCase): ...@@ -92,3 +96,7 @@ class HfFolderTest(unittest.TestCase):
HfFolder.get_token(), HfFolder.get_token(),
None None
) )
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment