Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f564f93c
Commit
f564f93c
authored
Mar 04, 2020
by
Julien Chaumond
Browse files
[hf_api] Get the public list of all the models on huggingface
parent
ff9e79ba
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
3 deletions
+74
-3
src/transformers/hf_api.py
src/transformers/hf_api.py
+59
-1
tests/test_hf_api.py
tests/test_hf_api.py
+15
-2
No files found.
src/transformers/hf_api.py
View file @
f564f93c
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
io
import
io
import
os
import
os
from
os.path
import
expanduser
from
os.path
import
expanduser
from
typing
import
List
from
typing
import
Dict
,
List
,
Optional
import
requests
import
requests
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
...
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class
S3Obj
:
class
S3Obj
:
"""
Data structure that represents a file belonging to the current user.
"""
def
__init__
(
self
,
filename
:
str
,
LastModified
:
str
,
ETag
:
str
,
Size
:
int
,
**
kwargs
):
def
__init__
(
self
,
filename
:
str
,
LastModified
:
str
,
ETag
:
str
,
Size
:
int
,
**
kwargs
):
self
.
filename
=
filename
self
.
filename
=
filename
self
.
LastModified
=
LastModified
self
.
LastModified
=
LastModified
...
@@ -41,6 +45,50 @@ class PresignedUrl:
...
@@ -41,6 +45,50 @@ class PresignedUrl:
self
.
type
=
type
# mime-type to send to S3.
self
.
type
=
type
# mime-type to send to S3.
class
S3Object
:
"""
Data structure that represents a public file accessible on our S3.
"""
def
__init__
(
self
,
key
:
str
,
# S3 object key
etag
:
str
,
lastModified
:
str
,
size
:
int
,
rfilename
:
str
,
# filename relative to config.json
**
kwargs
):
self
.
key
=
key
self
.
etag
=
etag
self
.
lastModified
=
lastModified
self
.
size
=
size
self
.
rfilename
=
rfilename
class
ModelInfo
:
"""
Info about a public model accessible from our S3.
"""
def
__init__
(
self
,
modelId
:
str
,
# id of model
key
:
str
,
# S3 object key of config.json
author
:
Optional
[
str
]
=
None
,
downloads
:
Optional
[
int
]
=
None
,
tags
:
List
[
str
]
=
[],
siblings
:
List
[
Dict
]
=
[],
**
kwargs
):
self
.
modelId
=
modelId
self
.
key
=
key
self
.
author
=
author
self
.
downloads
=
downloads
self
.
tags
=
tags
self
.
siblings
=
[
S3Object
(
**
x
)
for
x
in
siblings
]
class
HfApi
:
class
HfApi
:
def
__init__
(
self
,
endpoint
=
None
):
def
__init__
(
self
,
endpoint
=
None
):
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
...
@@ -129,6 +177,16 @@ class HfApi:
...
@@ -129,6 +177,16 @@ class HfApi:
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
.
raise_for_status
()
r
.
raise_for_status
()
def
model_list
(
self
)
->
List
[
ModelInfo
]:
"""
Get the public list of all the models on huggingface, including the community models
"""
path
=
"{}/api/models"
.
format
(
self
.
endpoint
)
r
=
requests
.
get
(
path
)
r
.
raise_for_status
()
d
=
r
.
json
()
return
[
ModelInfo
(
**
x
)
for
x
in
d
]
class
TqdmProgressFileReader
:
class
TqdmProgressFileReader
:
"""
"""
...
...
tests/test_hf_api.py
View file @
f564f93c
...
@@ -21,7 +21,7 @@ import unittest
...
@@ -21,7 +21,7 @@ import unittest
import
requests
import
requests
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
from
transformers.hf_api
import
HfApi
,
HfFolder
,
PresignedUrl
,
S3Obj
from
transformers.hf_api
import
HfApi
,
HfFolder
,
ModelInfo
,
PresignedUrl
,
S3Obj
USER
=
"__DUMMY_TRANSFORMERS_USER__"
USER
=
"__DUMMY_TRANSFORMERS_USER__"
...
@@ -36,10 +36,11 @@ FILES = [
...
@@ -36,10 +36,11 @@ FILES = [
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/empty.txt"
),
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/empty.txt"
),
),
),
]
]
ENDPOINT_STAGING
=
"https://moon-staging.huggingface.co"
class
HfApiCommonTest
(
unittest
.
TestCase
):
class
HfApiCommonTest
(
unittest
.
TestCase
):
_api
=
HfApi
(
endpoint
=
"https://moon-staging.huggingface.co"
)
_api
=
HfApi
(
endpoint
=
ENDPOINT_STAGING
)
class
HfApiLoginTest
(
HfApiCommonTest
):
class
HfApiLoginTest
(
HfApiCommonTest
):
...
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
...
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
self
.
assertIsInstance
(
o
,
S3Obj
)
self
.
assertIsInstance
(
o
,
S3Obj
)
class
HfApiPublicTest
(
unittest
.
TestCase
):
def
test_staging_model_list
(
self
):
_api
=
HfApi
(
endpoint
=
ENDPOINT_STAGING
)
_
=
_api
.
model_list
()
def
test_model_list
(
self
):
_api
=
HfApi
()
models
=
_api
.
model_list
()
self
.
assertGreater
(
len
(
models
),
100
)
self
.
assertIsInstance
(
models
[
0
],
ModelInfo
)
class
HfFolderTest
(
unittest
.
TestCase
):
class
HfFolderTest
(
unittest
.
TestCase
):
def
test_token_workflow
(
self
):
def
test_token_workflow
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment