Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cbf8f5d3
Commit
cbf8f5d3
authored
Mar 09, 2020
by
Julien Chaumond
Browse files
[model upload] Support for organizations
parent
525b6b1c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
18 deletions
+50
-18
src/transformers/commands/user.py
src/transformers/commands/user.py
+19
-5
src/transformers/hf_api.py
src/transformers/hf_api.py
+21
-12
tests/test_hf_api.py
tests/test_hf_api.py
+10
-1
No files found.
src/transformers/commands/user.py
View file @
cbf8f5d3
...
...
@@ -26,13 +26,16 @@ class UserCommands(BaseTransformersCLICommand):
s3_parser
=
parser
.
add_parser
(
"s3"
,
help
=
"{ls, rm} Commands to interact with the files you upload on S3."
)
s3_subparsers
=
s3_parser
.
add_subparsers
(
help
=
"s3 related commands"
)
ls_parser
=
s3_subparsers
.
add_parser
(
"ls"
)
ls_parser
.
add_argument
(
"--organization"
,
type
=
str
,
help
=
"Optional: organization namespace."
)
ls_parser
.
set_defaults
(
func
=
lambda
args
:
ListObjsCommand
(
args
))
rm_parser
=
s3_subparsers
.
add_parser
(
"rm"
)
rm_parser
.
add_argument
(
"filename"
,
type
=
str
,
help
=
"individual object filename to delete from S3."
)
rm_parser
.
add_argument
(
"--organization"
,
type
=
str
,
help
=
"Optional: organization namespace."
)
rm_parser
.
set_defaults
(
func
=
lambda
args
:
DeleteObjCommand
(
args
))
# upload
upload_parser
=
parser
.
add_parser
(
"upload"
)
upload_parser
.
add_argument
(
"path"
,
type
=
str
,
help
=
"Local path of the folder or individual file to upload."
)
upload_parser
.
add_argument
(
"--organization"
,
type
=
str
,
help
=
"Optional: organization namespace."
)
upload_parser
.
add_argument
(
"--filename"
,
type
=
str
,
default
=
None
,
help
=
"Optional: override individual object filename on S3."
)
...
...
@@ -91,8 +94,10 @@ class WhoamiCommand(BaseUserCommand):
print
(
"Not logged in"
)
exit
()
try
:
user
=
self
.
_api
.
whoami
(
token
)
user
,
orgs
=
self
.
_api
.
whoami
(
token
)
print
(
user
)
if
orgs
:
print
(
ANSI
.
bold
(
"orgs: "
),
","
.
join
(
orgs
))
except
HTTPError
as
e
:
print
(
e
)
...
...
@@ -130,7 +135,7 @@ class ListObjsCommand(BaseUserCommand):
print
(
"Not logged in"
)
exit
(
1
)
try
:
objs
=
self
.
_api
.
list_objs
(
token
)
objs
=
self
.
_api
.
list_objs
(
token
,
organization
=
self
.
args
.
organization
)
except
HTTPError
as
e
:
print
(
e
)
exit
(
1
)
...
...
@@ -148,7 +153,7 @@ class DeleteObjCommand(BaseUserCommand):
print
(
"Not logged in"
)
exit
(
1
)
try
:
self
.
_api
.
delete_obj
(
token
,
filename
=
self
.
args
.
filename
)
self
.
_api
.
delete_obj
(
token
,
filename
=
self
.
args
.
filename
,
organization
=
self
.
args
.
organization
)
except
HTTPError
as
e
:
print
(
e
)
exit
(
1
)
...
...
@@ -195,8 +200,15 @@ class UploadCommand(BaseUserCommand):
)
exit
(
1
)
user
,
_
=
self
.
_api
.
whoami
(
token
)
namespace
=
self
.
args
.
organization
if
self
.
args
.
organization
is
not
None
else
user
for
filepath
,
filename
in
files
:
print
(
"About to upload file {} to S3 under filename {}"
.
format
(
ANSI
.
bold
(
filepath
),
ANSI
.
bold
(
filename
)))
print
(
"About to upload file {} to S3 under filename {} and namespace {}"
.
format
(
ANSI
.
bold
(
filepath
),
ANSI
.
bold
(
filename
),
ANSI
.
bold
(
namespace
)
)
)
choice
=
input
(
"Proceed? [Y/n] "
).
lower
()
if
not
(
choice
==
""
or
choice
==
"y"
or
choice
==
"yes"
):
...
...
@@ -204,6 +216,8 @@ class UploadCommand(BaseUserCommand):
exit
()
print
(
ANSI
.
bold
(
"Uploading... This might take a while if files are large"
))
for
filepath
,
filename
in
files
:
access_url
=
self
.
_api
.
presign_and_upload
(
token
=
token
,
filename
=
filename
,
filepath
=
filepath
)
access_url
=
self
.
_api
.
presign_and_upload
(
token
=
token
,
filename
=
filename
,
filepath
=
filepath
,
organization
=
self
.
args
.
organization
)
print
(
"Your file now lives at:"
)
print
(
access_url
)
src/transformers/hf_api.py
View file @
cbf8f5d3
...
...
@@ -17,7 +17,7 @@
import
io
import
os
from
os.path
import
expanduser
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
requests
from
tqdm
import
tqdm
...
...
@@ -109,7 +109,7 @@ class HfApi:
d
=
r
.
json
()
return
d
[
"token"
]
def
whoami
(
self
,
token
:
str
)
->
str
:
def
whoami
(
self
,
token
:
str
)
->
Tuple
[
str
,
List
[
str
]]
:
"""
Call HF API to know "whoami"
"""
...
...
@@ -117,7 +117,7 @@ class HfApi:
r
=
requests
.
get
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
r
.
raise_for_status
()
d
=
r
.
json
()
return
d
[
"user"
]
return
d
[
"user"
]
,
d
[
"orgs"
]
def
logout
(
self
,
token
:
str
)
->
None
:
"""
...
...
@@ -127,24 +127,28 @@ class HfApi:
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
r
.
raise_for_status
()
def
presign
(
self
,
token
:
str
,
filename
:
str
)
->
PresignedUrl
:
def
presign
(
self
,
token
:
str
,
filename
:
str
,
organization
:
Optional
[
str
]
=
None
)
->
PresignedUrl
:
"""
Call HF API to get a presigned url to upload `filename` to S3.
"""
path
=
"{}/api/presign"
.
format
(
self
.
endpoint
)
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
=
requests
.
post
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
,
"organization"
:
organization
},
)
r
.
raise_for_status
()
d
=
r
.
json
()
return
PresignedUrl
(
**
d
)
def
presign_and_upload
(
self
,
token
:
str
,
filename
:
str
,
filepath
:
str
)
->
str
:
def
presign_and_upload
(
self
,
token
:
str
,
filename
:
str
,
filepath
:
str
,
organization
:
Optional
[
str
]
=
None
)
->
str
:
"""
Get a presigned url, then upload file to S3.
Outputs:
url: Read-only url for the stored file on S3.
"""
urls
=
self
.
presign
(
token
,
filename
=
filename
)
urls
=
self
.
presign
(
token
,
filename
=
filename
,
organization
=
organization
)
# streaming upload:
# https://2.python-requests.org/en/master/user/advanced/#streaming-uploads
#
...
...
@@ -159,22 +163,27 @@ class HfApi:
pf
.
close
()
return
urls
.
access
def
list_objs
(
self
,
token
:
str
)
->
List
[
S3Obj
]:
def
list_objs
(
self
,
token
:
str
,
organization
:
Optional
[
str
]
=
None
)
->
List
[
S3Obj
]:
"""
Call HF API to list all stored files for user.
Call HF API to list all stored files for user
(or one of their organizations)
.
"""
path
=
"{}/api/listObjs"
.
format
(
self
.
endpoint
)
r
=
requests
.
get
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
params
=
{
"organization"
:
organization
}
if
organization
is
not
None
else
None
r
=
requests
.
get
(
path
,
params
=
params
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)})
r
.
raise_for_status
()
d
=
r
.
json
()
return
[
S3Obj
(
**
x
)
for
x
in
d
]
def
delete_obj
(
self
,
token
:
str
,
filename
:
str
):
def
delete_obj
(
self
,
token
:
str
,
filename
:
str
,
organization
:
Optional
[
str
]
=
None
):
"""
Call HF API to delete a file stored by user
"""
path
=
"{}/api/deleteObj"
.
format
(
self
.
endpoint
)
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
,
"organization"
:
organization
},
)
r
.
raise_for_status
()
def
model_list
(
self
)
->
List
[
ModelInfo
]:
...
...
tests/test_hf_api.py
View file @
cbf8f5d3
...
...
@@ -67,8 +67,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
cls
.
_api
.
delete_obj
(
token
=
cls
.
_token
,
filename
=
FILE_KEY
)
def
test_whoami
(
self
):
user
=
self
.
_api
.
whoami
(
token
=
self
.
_token
)
user
,
orgs
=
self
.
_api
.
whoami
(
token
=
self
.
_token
)
self
.
assertEqual
(
user
,
USER
)
self
.
assertIsInstance
(
orgs
,
list
)
def
test_presign_invalid_org
(
self
):
with
self
.
assertRaises
(
HTTPError
):
_
=
self
.
_api
.
presign
(
token
=
self
.
_token
,
filename
=
"fake_org.txt"
,
organization
=
"fake"
)
def
test_presign_valid_org
(
self
):
urls
=
self
.
_api
.
presign
(
token
=
self
.
_token
,
filename
=
"valid_org.txt"
,
organization
=
"valid_org"
)
self
.
assertIsInstance
(
urls
,
PresignedUrl
)
def
test_presign
(
self
):
for
FILE_KEY
,
FILE_PATH
in
FILES
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment