Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3e5da38d
Unverified
Commit
3e5da38d
authored
Mar 06, 2020
by
Thomas Wolf
Committed by
GitHub
Mar 06, 2020
Browse files
Merge pull request #3132 from huggingface/hf_api_model_list
[hf_api] Get the public list of all the models on huggingface
parents
9499a377
3f067f44
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
3 deletions
+74
-3
src/transformers/hf_api.py
src/transformers/hf_api.py
+59
-1
tests/test_hf_api.py
tests/test_hf_api.py
+15
-2
No files found.
src/transformers/hf_api.py
View file @
3e5da38d
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
io
import
io
import
os
import
os
from
os.path
import
expanduser
from
os.path
import
expanduser
from
typing
import
List
from
typing
import
Dict
,
List
,
Optional
import
requests
import
requests
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
...
@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class
S3Obj
:
class
S3Obj
:
"""
Data structure that represents a file belonging to the current user.
"""
def
__init__
(
self
,
filename
:
str
,
LastModified
:
str
,
ETag
:
str
,
Size
:
int
,
**
kwargs
):
def
__init__
(
self
,
filename
:
str
,
LastModified
:
str
,
ETag
:
str
,
Size
:
int
,
**
kwargs
):
self
.
filename
=
filename
self
.
filename
=
filename
self
.
LastModified
=
LastModified
self
.
LastModified
=
LastModified
...
@@ -41,6 +45,50 @@ class PresignedUrl:
...
@@ -41,6 +45,50 @@ class PresignedUrl:
self
.
type
=
type
# mime-type to send to S3.
self
.
type
=
type
# mime-type to send to S3.
class
S3Object
:
"""
Data structure that represents a public file accessible on our S3.
"""
def
__init__
(
self
,
key
:
str
,
# S3 object key
etag
:
str
,
lastModified
:
str
,
size
:
int
,
rfilename
:
str
,
# filename relative to config.json
**
kwargs
):
self
.
key
=
key
self
.
etag
=
etag
self
.
lastModified
=
lastModified
self
.
size
=
size
self
.
rfilename
=
rfilename
class
ModelInfo
:
"""
Info about a public model accessible from our S3.
"""
def
__init__
(
self
,
modelId
:
str
,
# id of model
key
:
str
,
# S3 object key of config.json
author
:
Optional
[
str
]
=
None
,
downloads
:
Optional
[
int
]
=
None
,
tags
:
List
[
str
]
=
[],
siblings
:
List
[
Dict
]
=
[],
# list of files that constitute the model
**
kwargs
):
self
.
modelId
=
modelId
self
.
key
=
key
self
.
author
=
author
self
.
downloads
=
downloads
self
.
tags
=
tags
self
.
siblings
=
[
S3Object
(
**
x
)
for
x
in
siblings
]
class
HfApi
:
class
HfApi
:
def
__init__
(
self
,
endpoint
=
None
):
def
__init__
(
self
,
endpoint
=
None
):
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
self
.
endpoint
=
endpoint
if
endpoint
is
not
None
else
ENDPOINT
...
@@ -129,6 +177,16 @@ class HfApi:
...
@@ -129,6 +177,16 @@ class HfApi:
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
=
requests
.
delete
(
path
,
headers
=
{
"authorization"
:
"Bearer {}"
.
format
(
token
)},
json
=
{
"filename"
:
filename
})
r
.
raise_for_status
()
r
.
raise_for_status
()
def
model_list
(
self
)
->
List
[
ModelInfo
]:
"""
Get the public list of all the models on huggingface, including the community models
"""
path
=
"{}/api/models"
.
format
(
self
.
endpoint
)
r
=
requests
.
get
(
path
)
r
.
raise_for_status
()
d
=
r
.
json
()
return
[
ModelInfo
(
**
x
)
for
x
in
d
]
class
TqdmProgressFileReader
:
class
TqdmProgressFileReader
:
"""
"""
...
...
tests/test_hf_api.py
View file @
3e5da38d
...
@@ -21,7 +21,7 @@ import unittest
...
@@ -21,7 +21,7 @@ import unittest
import
requests
import
requests
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
from
transformers.hf_api
import
HfApi
,
HfFolder
,
PresignedUrl
,
S3Obj
from
transformers.hf_api
import
HfApi
,
HfFolder
,
ModelInfo
,
PresignedUrl
,
S3Obj
USER
=
"__DUMMY_TRANSFORMERS_USER__"
USER
=
"__DUMMY_TRANSFORMERS_USER__"
...
@@ -36,10 +36,11 @@ FILES = [
...
@@ -36,10 +36,11 @@ FILES = [
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/empty.txt"
),
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures/empty.txt"
),
),
),
]
]
ENDPOINT_STAGING
=
"https://moon-staging.huggingface.co"
class
HfApiCommonTest
(
unittest
.
TestCase
):
class
HfApiCommonTest
(
unittest
.
TestCase
):
_api
=
HfApi
(
endpoint
=
"https://moon-staging.huggingface.co"
)
_api
=
HfApi
(
endpoint
=
ENDPOINT_STAGING
)
class
HfApiLoginTest
(
HfApiCommonTest
):
class
HfApiLoginTest
(
HfApiCommonTest
):
...
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
...
@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
self
.
assertIsInstance
(
o
,
S3Obj
)
self
.
assertIsInstance
(
o
,
S3Obj
)
class
HfApiPublicTest
(
unittest
.
TestCase
):
def
test_staging_model_list
(
self
):
_api
=
HfApi
(
endpoint
=
ENDPOINT_STAGING
)
_
=
_api
.
model_list
()
def
test_model_list
(
self
):
_api
=
HfApi
()
models
=
_api
.
model_list
()
self
.
assertGreater
(
len
(
models
),
100
)
self
.
assertIsInstance
(
models
[
0
],
ModelInfo
)
class
HfFolderTest
(
unittest
.
TestCase
):
class
HfFolderTest
(
unittest
.
TestCase
):
def
test_token_workflow
(
self
):
def
test_token_workflow
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment