# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

5
6
from vllm.config import ProfilerConfig
from vllm.profiler.wrapper import WorkerProfiler
7
8
9
10
11
12
13


class ConcreteWorkerProfiler(WorkerProfiler):
    """
    A basic implementation of a worker profiler for testing purposes.

    ``_start`` and ``_stop`` do not drive a real profiler. They only count
    how often the base class invokes them. ``_start`` can be made to raise
    on demand via ``should_fail_start``, to exercise start-failure handling.
    """

    def __init__(self, profiler_config: ProfilerConfig) -> None:
        # Test instrumentation, assigned before delegating to the base
        # constructor so the attributes exist from the very beginning.
        self.start_call_count = 0
        self.stop_call_count = 0
        # When True, ``_start`` raises instead of recording a start.
        self.should_fail_start = False
        super().__init__(profiler_config)

    def _start(self) -> None:
        """Record a start, or raise if ``should_fail_start`` is set.

        Raises:
            RuntimeError: if ``should_fail_start`` is True.
        """
        if self.should_fail_start:
            raise RuntimeError("Simulated start failure")
        # Only successful starts are counted.
        self.start_call_count += 1

    def _stop(self) -> None:
        """Record a stop."""
        self.stop_call_count += 1


@pytest.fixture
def default_profiler_config():
    """Return a fresh torch ``ProfilerConfig`` for each test.

    Both ``delay_iterations`` and ``max_iterations`` start at 0. Individual
    tests overwrite these fields on the returned object before building a
    profiler from it.
    """
    config = ProfilerConfig(
        profiler="torch",
        torch_profiler_dir="/tmp/mock",
        delay_iterations=0,
        max_iterations=0,
    )
    return config


def test_immediate_start_stop(default_profiler_config):
    """Test standard start without delay."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # With delay_iterations == 0, start() launches the profiler right away.
    profiler.start()
    assert profiler._running is True
    assert profiler._active is True
    assert profiler.start_call_count == 1
    assert profiler.stop_call_count == 0

    profiler.stop()
    assert profiler._running is False
    assert profiler._active is False
    assert profiler.stop_call_count == 1


def test_delayed_start(default_profiler_config):
    """Test that profiler waits for N steps before actually starting."""
    default_profiler_config.delay_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # User requests start
    profiler.start()

    # Should be active (request accepted) but not running (waiting for delay)
    assert profiler._active is True
    assert profiler._running is False
    assert profiler.start_call_count == 0

    # Step 1: still inside the delay window.
    profiler.step()
    assert profiler._running is False
    assert profiler.start_call_count == 0

    # Step 2: delay threshold reached, the underlying profiler starts.
    profiler.step()
    assert profiler._running is True
    assert profiler.start_call_count == 1


def test_max_iterations(default_profiler_config):
    """Test that profiler stops automatically after max iterations."""
    default_profiler_config.max_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    assert profiler._running is True

    # Iteration 1
    profiler.step()
    assert profiler._profiling_for_iters == 1
    assert profiler._running is True

    # Iteration 2: reaches, but does not exceed, max_iterations.
    profiler.step()
    assert profiler._profiling_for_iters == 2
    assert profiler._running is True
    assert profiler.stop_call_count == 0

    # Iteration 3: _profiling_for_iters becomes 3, exceeding max_iterations.
    profiler.step()

    # Should have stopped automatically.
    assert profiler._running is False
    assert profiler.stop_call_count == 1


def test_delayed_start_and_max_iters(default_profiler_config):
    """Test combined delayed start and max iterations."""
    default_profiler_config.delay_iterations = 2
    default_profiler_config.max_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)
    profiler.start()

    # Step 1: still inside the delay window.
    profiler.step()
    assert profiler._running is False
    assert profiler._active is True

    # Step 2: delay reached, so the profiler starts. This step already
    # counts as the first profiled iteration.
    profiler.step()
    assert profiler._profiling_for_iters == 1
    assert profiler._running is True
    assert profiler._active is True
    assert profiler.start_call_count == 1

    # Step 3: second profiled iteration, still within max_iterations.
    profiler.step()
    assert profiler._profiling_for_iters == 2
    assert profiler._running is True
    assert profiler.stop_call_count == 0

    # Step 4: third profiled iteration, which exceeds max_iterations (2).
    profiler.step()

    # Should have stopped automatically.
    assert profiler._running is False
    assert profiler.stop_call_count == 1


def test_idempotency(default_profiler_config):
    """Test that calling start/stop multiple times doesn't break logic."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # Double start: the second request is ignored while already active.
    profiler.start()
    profiler.start()
    assert profiler.start_call_count == 1  # Should only start once
    assert profiler._running is True

    # Double stop: the second request is ignored once inactive.
    profiler.stop()
    profiler.stop()
    assert profiler.stop_call_count == 1  # Should only stop once
    assert profiler._running is False
    assert profiler._active is False


def test_step_inactive(default_profiler_config):
    """Test that stepping while inactive does nothing."""
    default_profiler_config.delay_iterations = 2
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # Not started yet, so these steps must not advance the delay countdown.
    profiler.step()
    profiler.step()

    # Two steps match delay_iterations, but start() was never requested,
    # so the profiler must still be idle.
    assert profiler.start_call_count == 0
    assert profiler._active is False
    assert profiler._running is False


def test_start_failure(default_profiler_config):
    """Test behavior when the underlying _start method raises exception."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)
    profiler.should_fail_start = True

    # Must not propagate the RuntimeError raised by _start.
    profiler.start()

    assert profiler._running is False  # Should not mark as running
    assert profiler._active is True  # Request is still considered active
    assert profiler.start_call_count == 0  # _start raised before counting


def test_shutdown(default_profiler_config):
    """Test that shutdown calls stop only if running."""
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    # Case 1: Not running, so shutdown must not invoke _stop.
    profiler.shutdown()
    assert profiler.stop_call_count == 0

    # Case 2: Running, so shutdown stops the profiler exactly once.
    profiler.start()
    profiler.shutdown()
    assert profiler.stop_call_count == 1
    assert profiler._running is False


def test_mixed_delay_and_stop(default_profiler_config):
    """Test manual stop during the delay period."""
    default_profiler_config.delay_iterations = 5
    profiler = ConcreteWorkerProfiler(default_profiler_config)

    profiler.start()
    profiler.step()
    profiler.step()

    # User cancels before delay finishes
    profiler.stop()
    assert profiler._active is False

    # 2 + 3 = 5 steps in total equals delay_iterations, so if stop() failed
    # to cancel the pending start, the last step below would trigger it.
    profiler.step()
    profiler.step()
    profiler.step()

    assert profiler.start_call_count == 0
    assert profiler._running is False