"docs/sphinx/requirements.txt" did not exist on "799f1c49ba85ec8c94ea24f0cf1c57b655658e5a"
test_stop_reason.py 1.83 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Test the different finish_reason="stop" situations during generation:
    1. One of the provided stop strings
    2. One of the provided stop tokens
    3. The EOS token

Run `pytest tests/engine/test_stop_reason.py`.
"""

import pytest
import transformers

from vllm import SamplingParams

# Model identifier handed both to the vLLM runner fixture and to the
# Hugging Face tokenizer loader, so both sides share one vocabulary.
MODEL = "facebook/opt-350m"
# Text whose token id is used for the stop-token case and which is the
# expected ``stop_reason`` in the stop-string case.
STOP_STR = "."
# Seed passed to every SamplingParams in this module.
SEED = 42
# Per-request generation budget passed as ``max_tokens``.
MAX_TOKENS = 1024


@pytest.fixture
def vllm_model(vllm_runner):
    """Provide a runner built from ``MODEL`` for the duration of one test.

    The runner is created through the ``vllm_runner`` fixture, handed to the
    test, and the local reference to it is dropped once the test is done.
    """
    runner = vllm_runner(MODEL)
    yield runner
    # Drop this fixture's reference after the test has finished with it.
    del runner


def test_stop_reason(vllm_model, example_prompts):
    """Check ``finish_reason``/``stop_reason`` for each way generation stops.

    Three cases are run against the same prompts:

    1. A stop token id is supplied: ``stop_reason`` must be that token id.
    2. A stop string is supplied: ``stop_reason`` must be that string.
    3. No explicit stop condition: the sequence ends either on the length
       limit or with ``finish_reason == "stop"`` and ``stop_reason is None``.

    Args:
        vllm_model: Fixture value; its ``.model`` attribute provides
            ``generate``.
        example_prompts: Prompts passed unchanged to every ``generate`` call.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
    llm = vllm_model.model

    # test stop token
    # ignore_eos=True: a sequence that samples EOS before the stop token
    # would otherwise finish with stop_reason=None and fail the check below.
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               ignore_eos=True,
                               seed=SEED,
                               max_tokens=MAX_TOKENS,
                               stop_token_ids=[stop_token_id]))
    for request_output in outputs:
        completion = request_output.outputs[0]
        assert completion.finish_reason == "stop"
        assert completion.stop_reason == stop_token_id

    # test stop string
    # Pass STOP_STR (not a repeated literal) so the requested stop string
    # and the asserted stop_reason cannot drift apart.
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               ignore_eos=True,
                               seed=SEED,
                               max_tokens=MAX_TOKENS,
                               stop=STOP_STR))
    for request_output in outputs:
        completion = request_output.outputs[0]
        assert completion.finish_reason == "stop"
        assert completion.stop_reason == STOP_STR

    # test EOS token
    # No stop condition is given, so either the length limit is hit or the
    # sequence stops without a stop token/string being reported.
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               seed=SEED, max_tokens=MAX_TOKENS))
    for request_output in outputs:
        completion = request_output.outputs[0]
        assert completion.finish_reason == "length" or (
            completion.finish_reason == "stop"
            and completion.stop_reason is None)