"vllm/model_executor/model_loader.py" did not exist on "7c041ab5784760416f85d68eb8925a1d1f981932"
test_sampling_params.py 3.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""

import pytest

from vllm.entrypoints.openai.responses.protocol import ResponsesRequest


class TestResponsesRequestSamplingParams:
    """Test that ResponsesRequest correctly maps parameters to SamplingParams.

    Each test builds a minimal ``ResponsesRequest`` (``model`` and ``input``
    are always supplied) and inspects the object returned by
    ``to_sampling_params``, or checks that request construction is rejected
    by pydantic validation.
    """

    def test_basic_sampling_params(self):
        """Test basic sampling parameters are correctly mapped."""
        request = ResponsesRequest(
            model="test-model",
            input="test input",
            temperature=0.8,
            top_p=0.95,
            top_k=50,
            max_output_tokens=100,
        )

        sampling_params = request.to_sampling_params(default_max_tokens=1000)

        assert sampling_params.temperature == 0.8
        assert sampling_params.top_p == 0.95
        assert sampling_params.top_k == 50
        # An explicit max_output_tokens must win over default_max_tokens.
        assert sampling_params.max_tokens == 100

    def test_extra_sampling_params(self):
        """Test extra sampling parameters are correctly mapped."""
        request = ResponsesRequest(
            model="test-model",
            input="test input",
            repetition_penalty=1.2,
            seed=42,
            stop=["END", "STOP"],
            ignore_eos=True,
            vllm_xargs={"custom": "value"},
        )

        sampling_params = request.to_sampling_params(default_max_tokens=1000)

        assert sampling_params.repetition_penalty == 1.2
        assert sampling_params.seed == 42
        assert sampling_params.stop == ["END", "STOP"]
        assert sampling_params.ignore_eos is True
        # vllm_xargs on the request surfaces as extra_args on SamplingParams.
        assert sampling_params.extra_args == {"custom": "value"}

    def test_stop_string_conversion(self):
        """Test that single stop string is converted to list."""
        request = ResponsesRequest(
            model="test-model",
            input="test input",
            stop="STOP",
        )

        sampling_params = request.to_sampling_params(default_max_tokens=1000)

        assert sampling_params.stop == ["STOP"]

    def test_default_values(self):
        """Test default values for optional parameters."""
        request = ResponsesRequest(
            model="test-model",
            input="test input",
        )

        sampling_params = request.to_sampling_params(default_max_tokens=1000)

        assert sampling_params.repetition_penalty == 1.0  # None → 1.0
        assert sampling_params.stop == []  # Empty list
        assert sampling_params.extra_args == {}  # Empty dict

    @staticmethod
    def _assert_seed_rejected(exc_info, expected_type: str) -> None:
        """Assert that a captured ValidationError reports ``expected_type`` on ``seed``.

        Inspects the structured ``errors()`` list instead of the rendered
        message, so the check pins both the offending field and the violated
        constraint and does not depend on pydantic's message formatting.

        Args:
            exc_info: The ``pytest.raises`` context holding the ValidationError.
            expected_type: Pydantic error type, e.g. ``"greater_than_equal"``.
        """
        errors = exc_info.value.errors()
        assert any(
            err["type"] == expected_type and tuple(err["loc"])[:1] == ("seed",)
            for err in errors
        ), errors

    def test_seed_bounds_validation(self):
        """Test that seed values outside torch.long bounds are rejected."""
        import torch
        from pydantic import ValidationError

        # Valid seed range is the signed 64-bit range of torch.long.
        long_info = torch.iinfo(torch.long)

        # Test seed below minimum
        with pytest.raises(ValidationError) as exc_info:
            ResponsesRequest(
                model="test-model",
                input="test input",
                seed=long_info.min - 1,
            )
        self._assert_seed_rejected(exc_info, "greater_than_equal")

        # Test seed above maximum
        with pytest.raises(ValidationError) as exc_info:
            ResponsesRequest(
                model="test-model",
                input="test input",
                seed=long_info.max + 1,
            )
        self._assert_seed_rejected(exc_info, "less_than_equal")

        # Test valid seed at boundaries (bounds are inclusive)
        request_min = ResponsesRequest(
            model="test-model",
            input="test input",
            seed=long_info.min,
        )
        assert request_min.seed == long_info.min

        request_max = ResponsesRequest(
            model="test-model",
            input="test input",
            seed=long_info.max,
        )
        assert request_max.seed == long_info.max