Unverified Commit 1eb89ddf authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2044 from huggingface/cli_upload

CLI for authenticated file sharing
parents fb0d2f1d 3ba417e1
...@@ -36,6 +36,12 @@ To create the package for pypi. ...@@ -36,6 +36,12 @@ To create the package for pypi.
from io import open from io import open
from setuptools import find_packages, setup from setuptools import find_packages, setup
extras = {
'serving': ['uvicorn', 'fastapi']
}
extras['all'] = [package for package in extras.values()]
setup( setup(
name="transformers", name="transformers",
version="2.2.1", version="2.2.1",
...@@ -61,6 +67,10 @@ setup( ...@@ -61,6 +67,10 @@ setup(
"transformers=transformers.__main__:main", "transformers=transformers.__main__:main",
] ]
}, },
extras_require=extras,
scripts=[
'transformers-cli'
],
# python_requires='>=3.5.0', # python_requires='>=3.5.0',
tests_require=['pytest'], tests_require=['pytest'],
classifiers=[ classifiers=[
......
#!/usr/bin/env python
from argparse import ArgumentParser
from transformers.commands.user import UserCommands
if __name__ == '__main__':
parser = ArgumentParser(description='Transformers CLI tool', usage='transformers-cli <command> [<args>]')
commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands
UserCommands.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
if not hasattr(args, 'func'):
parser.print_help()
exit(1)
# Run
service = args.func(args)
service.run()
from abc import ABC, abstractmethod
from argparse import ArgumentParser
class BaseTransformersCLICommand(ABC):
@staticmethod
@abstractmethod
def register_subcommand(parser: ArgumentParser):
raise NotImplementedError()
@abstractmethod
def run(self):
raise NotImplementedError()
from argparse import ArgumentParser
from getpass import getpass
import os
from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder, HTTPError
class UserCommands(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
login_parser = parser.add_parser('login')
login_parser.set_defaults(func=lambda args: LoginCommand(args))
whoami_parser = parser.add_parser('whoami')
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser('logout')
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
list_parser = parser.add_parser('ls')
list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
# upload
upload_parser = parser.add_parser('upload')
upload_parser.add_argument('file', type=str, help='Local filepath of the file to upload.')
upload_parser.add_argument('--filename', type=str, default=None, help='Optional: override object filename on S3.')
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
class ANSI:
"""
Helper for en.wikipedia.org/wiki/ANSI_escape_code
"""
_bold = u"\u001b[1m"
_reset = u"\u001b[0m"
@classmethod
def bold(cls, s):
return "{}{}{}".format(cls._bold, s, cls._reset)
class BaseUserCommand:
def __init__(self, args):
self.args = args
self._api = HfApi()
class LoginCommand(BaseUserCommand):
def run(self):
print("""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
""")
username = input("Username: ")
password = getpass()
try:
token = self._api.login(username, password)
except HTTPError as e:
# probably invalid credentials, display error message.
print(e)
exit(1)
HfFolder.save_token(token)
print("Login successful")
print("Your token:", token, "\n")
print("Your token has been saved to", HfFolder.path_token)
class WhoamiCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit()
try:
user = self._api.whoami(token)
print(user)
except HTTPError as e:
print(e)
class LogoutCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit()
HfFolder.delete_token()
self._api.logout(token)
print("Successfully logged out.")
class ListObjsCommand(BaseUserCommand):
def tabulate(self, rows, headers):
# type: (List[List[Union[str, int]]], List[str]) -> str
"""
Inspired by:
stackoverflow.com/a/8356620/593036
stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(
row_format.format(*headers)
)
lines.append(
row_format.format(*["-" * w for w in col_widths])
)
for row in rows:
lines.append(
row_format.format(*row)
)
return "\n".join(lines)
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
objs = self._api.list_objs(token)
except HTTPError as e:
print(e)
exit(1)
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [ [
obj.filename,
obj.LastModified,
obj.ETag,
obj.Size
] for obj in objs ]
print(
self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])
)
class UploadCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
filepath = os.path.join(os.getcwd(), self.args.file)
filename = self.args.filename if self.args.filename is not None else os.path.basename(filepath)
print(
"About to upload file {} to S3 under filename {}".format(
ANSI.bold(filepath), ANSI.bold(filename)
)
)
choice = input("Proceed? [Y/n] ").lower()
if not(choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
print(
ANSI.bold("Uploading... This might take a while if file is large")
)
access_url = self._api.presign_and_upload(
token=token, filename=filename, filepath=filepath
)
print("Your file now lives at:")
print(access_url)
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
from os.path import expanduser
import six
import requests
from requests.exceptions import HTTPError
ENDPOINT = "https://huggingface.co"
class S3Obj:
def __init__(
self,
filename, # type: str
LastModified, # type: str
ETag, # type: str
Size, # type: int
**kwargs
):
self.filename = filename
self.LastModified = LastModified
self.ETag = ETag
self.Size = Size
class PresignedUrl:
def __init__(
self,
write, # type: str
access, # type: str
type, # type: str
**kwargs
):
self.write = write
self.access = access
self.type = type # mime-type to send to S3.
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(
self,
username, # type: str
password, # type: str
):
# type: (...) -> str
"""
Call HF API to sign in a user and get a token if credentials are valid.
Outputs:
token if credentials are valid
Throws:
requests.exceptions.HTTPError if credentials are invalid
"""
path = "{}/api/login".format(self.endpoint)
r = requests.post(path, json={"username": username, "password": password})
r.raise_for_status()
d = r.json()
return d["token"]
def whoami(
self,
token, # type: str
):
# type: (...) -> str
"""
Call HF API to know "whoami"
"""
path = "{}/api/whoami".format(self.endpoint)
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return d["user"]
def logout(self, token):
# type: (...) -> void
"""
Call HF API to log out.
"""
path = "{}/api/logout".format(self.endpoint)
r = requests.post(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
def presign(self, token, filename):
# type: (...) -> PresignedUrl
"""
Call HF API to get a presigned url to upload `filename` to S3.
"""
path = "{}/api/presign".format(self.endpoint)
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"filename": filename},
)
r.raise_for_status()
d = r.json()
return PresignedUrl(**d)
def presign_and_upload(self, token, filename, filepath):
# type: (...) -> str
"""
Get a presigned url, then upload file to S3.
Outputs:
url: Read-only url for the stored file on S3.
"""
urls = self.presign(token, filename=filename)
# streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
#
# Even though we presign with the correct content-type,
# the client still has to specify it when uploading the file.
with open(filepath, "rb") as f:
r = requests.put(urls.write, data=f, headers={
"content-type": urls.type,
})
r.raise_for_status()
return urls.access
def list_objs(self, token):
# type: (...) -> List[S3Obj]
"""
Call HF API to list all stored files for user.
"""
path = "{}/api/listObjs".format(self.endpoint)
r = requests.get(path, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return [S3Obj(**x) for x in d]
class HfFolder:
path_token = expanduser("~/.huggingface/token")
@classmethod
def save_token(cls, token):
"""
Save token, creating folder as needed.
"""
if six.PY3:
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
else:
# Python 2
try:
os.makedirs(os.path.dirname(cls.path_token))
except OSError as e:
if e.errno != os.errno.EEXIST:
raise e
pass
with open(cls.path_token, 'w+') as f:
f.write(token)
@classmethod
def get_token(cls):
"""
Get token or None if not existent.
"""
try:
with open(cls.path_token, 'r') as f:
return f.read()
except:
# this is too wide. When Py2 is dead use:
# `except FileNotFoundError:` instead
return None
@classmethod
def delete_token(cls):
"""
Delete token.
Do not fail if token does not exist.
"""
try:
os.remove(cls.path_token)
except:
return
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import six
import time
import unittest
from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError
USER = "__DUMMY_TRANSFORMERS_USER__"
PASS = "__DUMMY_TRANSFORMERS_PASS__"
FILE_KEY = "Test-{}.txt".format(int(time.time()))
FILE_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt"
)
class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint="https://moon-staging.huggingface.co")
class HfApiLoginTest(HfApiCommonTest):
def test_login_invalid(self):
with self.assertRaises(HTTPError):
self._api.login(username=USER, password="fake")
def test_login_valid(self):
token = self._api.login(username=USER, password=PASS)
self.assertIsInstance(token, six.string_types)
class HfApiEndpointsTest(HfApiCommonTest):
@classmethod
def setUpClass(cls):
"""
Share this valid token in all tests below.
"""
cls._token = cls._api.login(username=USER, password=PASS)
def test_whoami(self):
user = self._api.whoami(token=self._token)
self.assertEqual(user, USER)
def test_presign(self):
urls = self._api.presign(token=self._token, filename=FILE_KEY)
self.assertIsInstance(urls, PresignedUrl)
self.assertEqual(urls.type, "text/plain")
def test_presign_and_upload(self):
access_url = self._api.presign_and_upload(
token=self._token, filename=FILE_KEY, filepath=FILE_PATH
)
self.assertIsInstance(access_url, six.string_types)
def test_list_objs(self):
objs = self._api.list_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1]
self.assertIsInstance(o, S3Obj)
class HfFolderTest(unittest.TestCase):
def test_token_workflow(self):
"""
Test the whole token save/get/delete workflow,
with the desired behavior with respect to non-existent tokens.
"""
token = "token-{}".format(int(time.time()))
HfFolder.save_token(token)
self.assertEqual(
HfFolder.get_token(),
token
)
HfFolder.delete_token()
HfFolder.delete_token()
# ^^ not an error, we test that the
# second call does not fail.
self.assertEqual(
HfFolder.get_token(),
None
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment