OpenDAS / text-generation-inference · Commits

Commit 323546df (Unverified)
Authored Apr 25, 2023 by OlivierDehaene; committed by GitHub on Apr 25, 2023

fix(python-client): add auth headers to is supported requests (#234)

Parent: 37b64a5c
Showing 2 changed files with 15 additions and 12 deletions.
clients/python/pyproject.toml  (+1 −1)
clients/python/text_generation/inference_api.py  (+14 −11)
clients/python/pyproject.toml

 [tool.poetry]
 name = "text-generation"
-version = "0.5.0"
+version = "0.5.1"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
 ...
clients/python/text_generation/inference_api.py

 import os
 import requests

-from typing import Optional, List
+from typing import Dict, Optional, List
 from huggingface_hub.utils import build_hf_headers

 from text_generation import Client, AsyncClient, __version__
 ...
@@ -13,7 +13,7 @@ INFERENCE_ENDPOINT = os.environ.get(
 )

-def deployed_models() -> List[DeployedModel]:
+def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
     """
     Get all currently deployed models with text-generation-inference-support
 ...
@@ -22,6 +22,7 @@ def deployed_models() -> List[DeployedModel]:
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/framework/text-generation-inference",
+        headers=headers,
         timeout=5,
     )
 ...
@@ -33,7 +34,7 @@ def deployed_models() -> List[DeployedModel]:
     return models

-def check_model_support(repo_id: str) -> bool:
+def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:
     """
     Check if a given model is supported by text-generation-inference
 ...
@@ -42,6 +43,7 @@ def check_model_support(repo_id: str) -> bool:
     """
     resp = requests.get(
         f"https://api-inference.huggingface.co/status/{repo_id}",
+        headers=headers,
         timeout=5,
     )
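Both helper functions now take an optional `headers` dict that is forwarded to the api-inference status endpoints, so the support checks run authenticated. A minimal usage sketch, assuming a placeholder token and an example model id (neither appears in the commit):

from huggingface_hub.utils import build_hf_headers
from text_generation.inference_api import deployed_models, check_model_support

# Build the same style of auth headers the clients construct internally.
# "hf_..." is a placeholder for a real Hugging Face access token.
headers = build_hf_headers(token="hf_...", library_name="text-generation")

# Both calls now send the Authorization header along with the request.
models = deployed_models(headers=headers)
is_supported = check_model_support("bigscience/bloom", headers=headers)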
 ...
@@ -95,13 +97,14 @@ class InferenceAPIClient(Client):
             Timeout in seconds
         """

-        # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
-            raise NotSupportedError(repo_id)
-
         headers = build_hf_headers(
             token=token, library_name="text-generation", library_version=__version__
         )
+
+        # Text Generation Inference client only supports a subset of the available hub models
+        if not check_model_support(repo_id, headers):
+            raise NotSupportedError(repo_id)
+
         base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

         super(InferenceAPIClient, self).__init__(
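The client now builds the auth headers first and passes them into `check_model_support`, so the support check itself is authenticated rather than anonymous. Nothing changes for callers; a short sketch (placeholder token and example model id):

from text_generation import InferenceAPIClient

# The token passed here is now also used for the model-support check,
# not only for the subsequent generate requests. "hf_..." is a placeholder.
client = InferenceAPIClient("bigscience/bloomz", token="hf_...")
print(client.generate("Why is the sky blue?", max_new_tokens=20).generated_text)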
 ...
@@ -150,14 +153,14 @@ class InferenceAPIAsyncClient(AsyncClient):
         timeout (`int`):
             Timeout in seconds
         """
+        headers = build_hf_headers(
+            token=token, library_name="text-generation", library_version=__version__
+        )

         # Text Generation Inference client only supports a subset of the available hub models
-        if not check_model_support(repo_id):
+        if not check_model_support(repo_id, headers):
             raise NotSupportedError(repo_id)

-        headers = build_hf_headers(
-            token=token, library_name="text-generation", library_version=__version__
-        )
         base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"

         super(InferenceAPIAsyncClient, self).__init__(
 ...
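The async client receives the identical reordering. An equivalent async sketch, under the same placeholder assumptions:

import asyncio
from text_generation import InferenceAPIAsyncClient

async def main():
    # Headers are built and used for the support check before the
    # client is initialized; "hf_..." is a placeholder token.
    client = InferenceAPIAsyncClient("bigscience/bloomz", token="hf_...")
    response = await client.generate("Why is the sky blue?", max_new_tokens=20)
    print(response.generated_text)

asyncio.run(main())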