# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm import SamplingParams
14
from vllm.model_executor.utils import set_random_seed
Nick Hill's avatar
Nick Hill committed
15
16
17
18
19
20

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner, monkeypatch):
    """Yield a runner for ``MODEL`` with ``VLLM_USE_V1`` forced to ``"0"``.

    The environment variable is set before the runner is created because
    the tests in this file rely on V0 internals. The runner is used as a
    context manager, so it is exited once the consuming test finishes.
    """
    monkeypatch.setenv("VLLM_USE_V1", "0")

    runner_cm = vllm_runner(MODEL, dtype="half")
    with runner_cm as runner:
        yield runner
Nick Hill's avatar
Nick Hill committed
26
27
28
29
30
31
32
33
34
35
36
37


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    """Check that a per-request seed makes random sampling reproducible.

    For every prompt, six requests are queued: unseeded, seed 100, seed 200,
    and then the same three again. The test then asserts, for every prompt:

    * the first four requests (unseeded, seed 100, seed 200, unseeded) all
      produce different outputs;
    * the two seed-100 requests match, and the two seed-200 requests match;
    * within each request, the ``n`` parallel samples all differ.

    Args:
        vllm_model: Fixture exposing the engine under test via ``.llm``.
        example_prompts: Fixture providing the prompts to generate from.
        seed: Value passed to ``set_random_seed`` before the sampling
            parameters are drawn with ``random``.
    """
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=3.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    # Requests queued per prompt, in order. The index-based assertions below
    # depend on this exact layout.
    params_per_prompt = (
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
    )
    group_size = len(params_per_prompt)

    llm = vllm_model.llm

    for prompt in example_prompts:
        for params in params_per_prompt:
            llm._add_request(prompt, params=params)

    # NOTE(review): the grouping below assumes _run_engine returns one result
    # per request, in submission order -- confirm against the engine.
    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    assert len(all_outputs) == len(example_prompts) * group_size, (
        "expected one result per queued request")

    # Step over the outputs (not the prompts) so that every prompt's group of
    # requests is verified, rather than only the first few groups.
    for i in range(0, len(all_outputs), group_size):
        outputs = all_outputs[i:i + group_size]

        # verify requests that do not share a seed differ
        # (unseeded, seed 100, seed 200, unseeded)
        for output_a, output_b in combinations(
            (outputs[0], outputs[1], outputs[2], outputs[3]),
                2,
        ):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

        # verify generations within the same parallel sampling group differ
        for output in outputs:
            for sub_output_a, sub_output_b in combinations(output, 2):
                assert sub_output_a != sub_output_b