# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Test the different finish_reason="stop" situations during generation:
    1. One of the provided stop strings
    2. One of the provided stop tokens
    3. The EOS token

Run `pytest tests/engine/test_stop_reason.py`.
"""

import os

import pytest
import transformers

from vllm import SamplingParams

from ..utils import models_path_prefix

# distilgpt2 is resolved relative to ``models_path_prefix`` (presumably a
# local model mirror configured in tests/utils — confirm there).
MODEL = os.path.join(models_path_prefix, "distilbert/distilgpt2")

# STOP_STR is the stop string and, converted by the tokenizer, the stop token
# id used by test_stop_reason, so it must be a single token in MODEL's vocab.
STOP_STR = "."
SEED = 42
MAX_TOKENS = 1024


@pytest.fixture
def vllm_model(vllm_runner):
    """Yield a ``vllm_runner`` instance for ``MODEL``.

    The runner is used as a context manager, so its context is exited once
    the requesting test finishes.
    """
    with vllm_runner(MODEL) as runner:
        yield runner


def _first_completions(llm, prompts, sampling_params):
    """Generate for ``prompts`` and return the first completion of each request."""
    request_outputs = llm.generate(prompts, sampling_params=sampling_params)
    return [request_output.outputs[0] for request_output in request_outputs]


def test_stop_reason(vllm_model, example_prompts):
    """Check ``finish_reason``/``stop_reason`` for stop tokens, stop strings and EOS."""
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
    # convert_tokens_to_ids falls back to the unknown-token id (or None) for
    # strings outside the vocabulary; fail loudly instead of silently testing
    # a different token.
    assert stop_token_id is not None and stop_token_id != tokenizer.unk_token_id, (
        f"{STOP_STR!r} is not a single token in the vocabulary of {MODEL}"
    )
    llm = vllm_model.llm

    # Stop token: EOS is ignored, so the request must end on the stop token
    # and report its id as the stop reason.
    stop_token_params = SamplingParams(
        ignore_eos=True,
        seed=SEED,
        max_tokens=MAX_TOKENS,
        stop_token_ids=[stop_token_id],
    )
    for completion in _first_completions(llm, example_prompts, stop_token_params):
        assert completion.finish_reason == "stop", completion.finish_reason
        assert completion.stop_reason == stop_token_id, completion.stop_reason

    # Stop string: use STOP_STR (not a duplicated literal) so the request and
    # the assertion can never drift apart.
    stop_str_params = SamplingParams(
        ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop=STOP_STR
    )
    for completion in _first_completions(llm, example_prompts, stop_str_params):
        assert completion.finish_reason == "stop", completion.finish_reason
        assert completion.stop_reason == STOP_STR, completion.stop_reason

    # EOS: with no stop strings/tokens, generation ends either at MAX_TOKENS
    # ("length") or at EOS, which is reported as "stop" with no stop_reason.
    eos_params = SamplingParams(seed=SEED, max_tokens=MAX_TOKENS)
    for completion in _first_completions(llm, example_prompts, eos_params):
        assert completion.finish_reason == "length" or (
            completion.finish_reason == "stop" and completion.stop_reason is None
        ), (completion.finish_reason, completion.stop_reason)