"vllm/vscode:/vscode.git/clone" did not exist on "9ef3b718d98ada5c02ebe67c2a545d3005a6b10d"
test_encoder_decoder.py 1.85 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import openai
import pytest
6
import os
7
import pytest_asyncio
8

9
from ...utils import RemoteOpenAIServer, models_path_prefix
10
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
zhuwenwen's avatar
zhuwenwen committed
11
from vllm.platforms import current_platform
12

13
MODEL_NAME = os.path.join(models_path_prefix, "facebook/bart-base")
14
15
16
17
18
19
20
21
22
23
24
25
26
27


@pytest.fixture(scope="module")
def server():
    """Yield a ``RemoteOpenAIServer`` started for ``MODEL_NAME``.

    The fixture is module-scoped, so pytest creates it once and shares it
    across the tests in this module. The server object is used as a context
    manager and is exited after the last dependent test finishes.
    """
    # Command-line flags forwarded to the server: run in bfloat16 and pass
    # --enforce-eager.
    server_cli_args = ["--dtype", "bfloat16", "--enforce-eager"]

    with RemoteOpenAIServer(MODEL_NAME, server_cli_args) as running_server:
        yield running_server


28
29
30
31
@pytest_asyncio.fixture
async def client(server):
    """Yield the async client returned by ``server.get_async_client()``.

    The client is entered as an async context manager and is exited once
    the requesting test has finished.
    """
    async with server.get_async_client() as api_client:
        yield api_client
32
33


zhuwenwen's avatar
zhuwenwen committed
34
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    """Exercise the completions endpoint with a text and a token-ID prompt.

    The test is skipped when ``current_platform.is_rocm()`` is true.

    Args:
        client: Async OpenAI client supplied by the ``client`` fixture.
        model_name: Model identifier sent with each request; parametrized
            with ``MODEL_NAME``.
    """
    # Text prompt: request at most five tokens with temperature 0.0.
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1

    choice = completion.choices[0]
    assert len(choice.text) >= 5
    # The request is expected to stop because it reached the max_tokens cap.
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=2, total_tokens=7)

    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    # Only require that some text came back for the token-ID prompt.
    assert len(completion.choices[0].text) >= 1