Unverified commit 3e5da38d, authored by Thomas Wolf and committed by GitHub
Browse files

Merge pull request #3132 from huggingface/hf_api_model_list

[hf_api] Get the public list of all the models on huggingface
parents 9499a377 3f067f44
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import io import io
import os import os
from os.path import expanduser from os.path import expanduser
from typing import List from typing import Dict, List, Optional
import requests import requests
from tqdm import tqdm from tqdm import tqdm
...@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co" ...@@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class S3Obj: class S3Obj:
"""
Data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs): def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
self.filename = filename self.filename = filename
self.LastModified = LastModified self.LastModified = LastModified
...@@ -41,6 +45,50 @@ class PresignedUrl: ...@@ -41,6 +45,50 @@ class PresignedUrl:
self.type = type # mime-type to send to S3. self.type = type # mime-type to send to S3.
class S3Object:
    """
    Data structure that represents a public file accessible on our S3.

    Args:
        key: S3 object key.
        etag: ETag string of the object (stored as given).
        lastModified: Last-modified value of the object (stored as given).
        size: Size of the object (presumably in bytes -- confirm against the API).
        rfilename: Filename relative to config.json.
        **kwargs: Any extra fields are accepted and ignored, so that unknown
            keys in the input do not make construction fail.
    """

    def __init__(
        self,
        key: str,  # S3 object key
        etag: str,
        lastModified: str,
        size: int,
        rfilename: str,  # filename relative to config.json
        **kwargs
    ):
        self.key = key
        self.etag = etag
        self.lastModified = lastModified
        self.size = size
        self.rfilename = rfilename

    def __repr__(self) -> str:
        # Debug-friendly representation; uses str.format rather than an
        # f-string to match the formatting style used elsewhere in the file.
        return "{}(key={!r}, etag={!r}, lastModified={!r}, size={!r}, rfilename={!r})".format(
            type(self).__name__, self.key, self.etag, self.lastModified, self.size, self.rfilename
        )
class ModelInfo:
    """
    Info about a public model accessible from our S3.

    Args:
        modelId: Id of the model.
        key: S3 object key of config.json.
        author: Optional author of the model.
        downloads: Optional download count.
        tags: Optional list of tags; defaults to a new empty list.
        siblings: Optional list of dicts describing the files that constitute
            the model; each dict is expanded into an ``S3Object``.
        **kwargs: Any extra fields are accepted and ignored.
    """

    def __init__(
        self,
        modelId: str,  # id of model
        key: str,  # S3 object key of config.json
        author: Optional[str] = None,
        downloads: Optional[int] = None,
        tags: Optional[List[str]] = None,
        siblings: Optional[List[Dict]] = None,  # list of files that constitute the model
        **kwargs
    ):
        self.modelId = modelId
        self.key = key
        self.author = author
        self.downloads = downloads
        # A `[]` default would be created once and shared by every instance
        # built without `tags`; build a fresh list per instance instead.
        self.tags = [] if tags is None else tags
        # `None` (omitted, or an explicit null from the caller) means "no files".
        self.siblings = [S3Object(**x) for x in (siblings or [])]
class HfApi: class HfApi:
def __init__(self, endpoint=None): def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT self.endpoint = endpoint if endpoint is not None else ENDPOINT
...@@ -129,6 +177,16 @@ class HfApi: ...@@ -129,6 +177,16 @@ class HfApi:
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename}) r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status() r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
    """
    Get the public list of all the models on huggingface, including the community models.

    Sends a GET request to ``<endpoint>/api/models``, raises if the response
    carries an HTTP error status, and wraps every entry of the decoded JSON
    payload in a ``ModelInfo``.
    """
    url = "{}/api/models".format(self.endpoint)
    response = requests.get(url)
    response.raise_for_status()
    return [ModelInfo(**entry) for entry in response.json()]
class TqdmProgressFileReader: class TqdmProgressFileReader:
""" """
......
...@@ -21,7 +21,7 @@ import unittest ...@@ -21,7 +21,7 @@ import unittest
import requests import requests
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
USER = "__DUMMY_TRANSFORMERS_USER__" USER = "__DUMMY_TRANSFORMERS_USER__"
...@@ -36,10 +36,11 @@ FILES = [ ...@@ -36,10 +36,11 @@ FILES = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"), os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
), ),
] ]
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
class HfApiCommonTest(unittest.TestCase): class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint="https://moon-staging.huggingface.co") _api = HfApi(endpoint=ENDPOINT_STAGING)
class HfApiLoginTest(HfApiCommonTest): class HfApiLoginTest(HfApiCommonTest):
...@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest): ...@@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertIsInstance(o, S3Obj) self.assertIsInstance(o, S3Obj)
class HfApiPublicTest(unittest.TestCase):
    """Tests for the public model listing, which is called without a token."""

    def test_staging_model_list(self):
        # Smoke test: only checks that the call against staging does not raise.
        staging_api = HfApi(endpoint=ENDPOINT_STAGING)
        staging_api.model_list()

    def test_model_list(self):
        # Default endpoint: expect more than 100 entries, each a ModelInfo.
        model_infos = HfApi().model_list()
        self.assertGreater(len(model_infos), 100)
        self.assertIsInstance(model_infos[0], ModelInfo)
class HfFolderTest(unittest.TestCase): class HfFolderTest(unittest.TestCase):
def test_token_workflow(self): def test_token_workflow(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment