Unverified Commit 323546df authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(python-client): add auth headers to the model-support ("is supported") requests (#234)

parent 37b64a5c
[tool.poetry] [tool.poetry]
name = "text-generation" name = "text-generation"
version = "0.5.0" version = "0.5.1"
description = "Hugging Face Text Generation Python Client" description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0" license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
......
import os import os
import requests import requests
from typing import Optional, List from typing import Dict, Optional, List
from huggingface_hub.utils import build_hf_headers from huggingface_hub.utils import build_hf_headers
from text_generation import Client, AsyncClient, __version__ from text_generation import Client, AsyncClient, __version__
...@@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get( ...@@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
) )
def deployed_models() -> List[DeployedModel]: def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
""" """
Get all currently deployed models with text-generation-inference-support Get all currently deployed models with text-generation-inference-support
...@@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]: ...@@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
""" """
resp = requests.get( resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference", f"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5, timeout=5,
) )
...@@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]: ...@@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
return models return models
def check_model_support(repo_id: str) -> bool: def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
""" """
Check if a given model is supported by text-generation-inference Check if a given model is supported by text-generation-inference
...@@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool: ...@@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
""" """
resp = requests.get( resp = requests.get(
f"https://api-inference.huggingface.co/status/{repo_id}", f"https://api-inference.huggingface.co/status/{repo_id}",
headers=headers,
timeout=5, timeout=5,
) )
...@@ -95,13 +97,14 @@ class InferenceAPIClient(Client): ...@@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
Timeout in seconds Timeout in seconds
""" """
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id):
raise NotSupportedError(repo_id)
headers = build_hf_headers( headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__ token=token, library_name="text-generation", library_version=__version__
) )
# Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__( super(InferenceAPIClient, self).__init__(
...@@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient): ...@@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
""" """
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
# Text Generation Inference client only supports a subset of the available hub models # Text Generation Inference client only supports a subset of the available hub models
if not check_model_support(repo_id): if not check_model_support(repo_id, headers):
raise NotSupportedError(repo_id) raise NotSupportedError(repo_id)
headers = build_hf_headers(
token=token, library_name="text-generation", library_version=__version__
)
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__( super(InferenceAPIAsyncClient, self).__init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment