OpenDAS / text-generation-inference

Commit b927244e (unverified)
Authored Apr 17, 2023 by OlivierDehaene; committed by GitHub on Apr 17, 2023
Parent: c13b9d87

feat(python-client): get list of currently deployed tgi models using the inference API (#191)
Showing 6 changed files with 69 additions and 31 deletions:

  README.md                                         +1  -1
  clients/python/README.md                          +13 -0
  clients/python/pyproject.toml                     +1  -1
  clients/python/tests/test_inference_api.py        +12 -4
  clients/python/text_generation/inference_api.py   +36 -25
  clients/python/text_generation/types.py           +6  -0
README.md

@@ -22,7 +22,7 @@ to power LLMs api-inference widgets.
 ## Table of contents

 - [Features](#features)
-- [Officially Supported Models](#officially-supported-models)
+- [Optimized Architectures](#optimized-architectures)
 - [Get Started](#get-started)
 - [Docker](#docker)
 - [API Documentation](#api-documentation)
clients/python/README.md

@@ -52,6 +52,14 @@ print(text)
 # ' Rayleigh scattering'
 ```

+Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
+
+```python
+from text_generation.inference_api import deployed_models
+
+print(deployed_models())
+```
+
 ### Hugging Face Inference Endpoint usage

 ```python

@@ -193,4 +201,9 @@ class StreamResponse:
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+# Inference API currently deployed model
+class DeployedModel:
+    model_id: str
+    sha: str
 ```
\ No newline at end of file
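The helper documented above returns `DeployedModel` objects rather than plain strings, so a caller can inspect both the repo id and the exact revision being served. A minimal sketch of that usage (output will vary with whatever happens to be deployed at the time):

```python
from text_generation.inference_api import deployed_models

# Each entry is a DeployedModel with `model_id` and `sha` fields,
# matching the class documented in the README addition above.
for model in deployed_models():
    print(f"{model.model_id} @ {model.sha}")
```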
clients/python/pyproject.toml

 [tool.poetry]
 name = "text-generation"
-version = "0.4.1"
+version = "0.5.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
clients/python/tests/test_inference_api.py

@@ -6,12 +6,20 @@ from text_generation import (
     Client,
     AsyncClient,
 )
-from text_generation.errors import NotSupportedError
-from text_generation.inference_api import get_supported_models
+from text_generation.errors import NotSupportedError, NotFoundError
+from text_generation.inference_api import check_model_support, deployed_models


-def test_get_supported_models():
-    assert isinstance(get_supported_models(), list)
+def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
+    assert check_model_support(flan_t5_xxl)
+    assert not check_model_support(unsupported_model)
+
+    with pytest.raises(NotFoundError):
+        check_model_support(fake_model)
+
+
+def test_deployed_models():
+    deployed_models()


 def test_client(flan_t5_xxl):
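The rewritten test leans on three pytest fixtures (`flan_t5_xxl`, `unsupported_model`, `fake_model`) that live outside this diff. A plausible `conftest.py` sketch; the concrete repo ids below are assumptions, not part of this commit:

```python
import pytest


@pytest.fixture
def flan_t5_xxl():
    # Assumed: a model id served by text-generation-inference on the API
    return "google/flan-t5-xxl"


@pytest.fixture
def unsupported_model():
    # Assumed: hosted on the Inference API, but behind a different framework
    return "gpt2"


@pytest.fixture
def fake_model():
    # Assumed: a repo id that does not exist, so /status/{repo_id} returns 404
    return "fake/model"
```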
clients/python/text_generation/inference_api.py
View file @
b927244e
import
os
import
requests
import
base64
import
json
import
warnings
from
typing
import
List
,
Optional
from
typing
import
Optional
,
List
from
huggingface_hub.utils
import
build_hf_headers
from
text_generation
import
Client
,
AsyncClient
,
__version__
from
text_generation.errors
import
NotSupportedError
from
text_generation.types
import
DeployedModel
from
text_generation.errors
import
NotSupportedError
,
parse_error
INFERENCE_ENDPOINT
=
os
.
environ
.
get
(
"HF_INFERENCE_ENDPOINT"
,
"https://api-inference.huggingface.co"
)
SUPPORTED_MODELS
=
None
def
get_supported_models
()
->
Optional
[
List
[
str
]]:
def
deployed_models
()
->
List
[
DeployedModel
]:
"""
Get
the list of supported text-generation models from GitHub
Get
all currently deployed models with text-generation-inference-support
Returns:
Optional[List[str]]: supported m
odel
s
list o
r None if unable to get the list from GitHub
List[DeployedM
odel
]:
list o
f all currently deployed models
"""
global
SUPPORTED_MODELS
if
SUPPORTED_MODELS
is
not
None
:
return
SUPPORTED_MODELS
resp
=
requests
.
get
(
f
"https://api-inference.huggingface.co/framework/text-generation-inference"
,
timeout
=
5
,
)
payload
=
resp
.
json
()
if
resp
.
status_code
!=
200
:
raise
parse_error
(
resp
.
status_code
,
payload
)
models
=
[
DeployedModel
(
**
raw_deployed_model
)
for
raw_deployed_model
in
payload
]
return
models
response
=
requests
.
get
(
"https://api.github.com/repos/huggingface/text-generation-inference/contents/supported_models.json"
,
def
check_model_support
(
repo_id
:
str
)
->
bool
:
"""
Check if a given model is supported by text-generation-inference
Returns:
bool: whether the model is supported by this client
"""
resp
=
requests
.
get
(
f
"https://api-inference.huggingface.co/status/
{
repo_id
}
"
,
timeout
=
5
,
)
if
response
.
status_code
==
200
:
file_content
=
response
.
json
()[
"content"
]
SUPPORTED_MODELS
=
json
.
loads
(
base64
.
b64decode
(
file_content
).
decode
(
"utf-8"
))
return
SUPPORTED_MODELS
warnings
.
warn
(
"Could not retrieve list of supported models."
)
return
None
payload
=
resp
.
json
()
if
resp
.
status_code
!=
200
:
raise
parse_error
(
resp
.
status_code
,
payload
)
framework
=
payload
[
"framework"
]
supported
=
framework
==
"text-generation-inference"
return
supported
class
InferenceAPIClient
(
Client
):
...
...
@@ -83,8 +96,7 @@ class InferenceAPIClient(Client):
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models
=
get_supported_models
()
if
supported_models
is
not
None
and
repo_id
not
in
supported_models
:
if
not
check_model_support
(
repo_id
):
raise
NotSupportedError
(
repo_id
)
headers
=
build_hf_headers
(
...
...
@@ -140,8 +152,7 @@ class InferenceAPIAsyncClient(AsyncClient):
"""
# Text Generation Inference client only supports a subset of the available hub models
supported_models
=
get_supported_models
()
if
supported_models
is
not
None
and
repo_id
not
in
supported_models
:
if
not
check_model_support
(
repo_id
):
raise
NotSupportedError
(
repo_id
)
headers
=
build_hf_headers
(
...
...
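Both clients now delegate the eligibility check to `check_model_support`, which queries `/status/{repo_id}` and compares the reported `framework` to `"text-generation-inference"`. The same guard can be applied by hand before constructing a client; a short sketch, where the model id is only an example:

```python
from text_generation import InferenceAPIClient
from text_generation.inference_api import check_model_support

repo_id = "google/flan-t5-xxl"  # example id, not prescribed by this commit

# check_model_support raises NotFoundError for unknown repos (per the test
# above) and returns False when the model is served by another framework.
if check_model_support(repo_id):
    client = InferenceAPIClient(repo_id)
    print(client.generate("Why is the sky blue?").generated_text)
else:
    print(f"{repo_id} is not served by text-generation-inference")
```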
clients/python/text_generation/types.py

@@ -223,3 +223,9 @@ class StreamResponse(BaseModel):
     # Generation details
     # Only available when the generation is finished
     details: Optional[StreamDetails]
+
+
+# Inference API currently deployed model
+class DeployedModel(BaseModel):
+    model_id: str
+    sha: str
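Because `DeployedModel` is a pydantic `BaseModel`, `deployed_models()` can build validated instances directly from the raw JSON payload with `DeployedModel(**raw)`. A small illustration; the payload values here are invented:

```python
from text_generation.types import DeployedModel

# Shape of one element of the /framework/text-generation-inference payload.
raw = {"model_id": "bigscience/bloom", "sha": "0123abcd"}

model = DeployedModel(**raw)
assert model.model_id == "bigscience/bloom"
assert model.sha == "0123abcd"
```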