# test_inference_api.py
import pytest

from text_generation import (
    InferenceAPIClient,
    InferenceAPIAsyncClient,
    Client,
    AsyncClient,
)
from text_generation.errors import NotSupportedError, NotFoundError
from text_generation.inference_api import check_model_support, deployed_models


def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):
    """Exercise check_model_support on three kinds of model ids.

    The result must be truthy for ``flan_t5_xxl`` and falsy for
    ``unsupported_model``. The call must raise ``NotFoundError`` for
    ``fake_model``. All three ids come from fixtures defined outside this
    module.
    """
    supported_result = check_model_support(flan_t5_xxl)
    assert supported_result

    unsupported_result = check_model_support(unsupported_model)
    assert not unsupported_result

    # An unknown model is reported via an exception, not a falsy return.
    pytest.raises(NotFoundError, check_model_support, fake_model)


def test_deployed_models():
    """Smoke-test that deployed_models() can be called without raising."""
    # The return value is not inspected; only successful completion is checked.
    _ = deployed_models()


def test_client(flan_t5_xxl):
    """InferenceAPIClient for a supported model is an instance of Client."""
    client = InferenceAPIClient(flan_t5_xxl)
    assert isinstance(client, Client)


def test_client_unsupported_model(unsupported_model):
    """Building an InferenceAPIClient for an unsupported model raises NotSupportedError."""
    pytest.raises(NotSupportedError, InferenceAPIClient, unsupported_model)


def test_async_client(flan_t5_xxl):
    """InferenceAPIAsyncClient for a supported model is an instance of AsyncClient."""
    client = InferenceAPIAsyncClient(flan_t5_xxl)
    assert isinstance(client, AsyncClient)


def test_async_client_unsupported_model(unsupported_model):
    """Building an InferenceAPIAsyncClient for an unsupported model raises NotSupportedError."""
    pytest.raises(NotSupportedError, InferenceAPIAsyncClient, unsupported_model)