# test_metrics.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import time
5

6
import pytest
7
import ray
8
9
from prometheus_client import REGISTRY

10
import vllm.envs as envs
11
from vllm import EngineArgs, LLMEngine
12
from vllm.distributed import cleanup_dist_env_and_memory
13
14
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
15
from vllm.engine.metrics import RayPrometheusStatLogger
16
from vllm.sampling_params import SamplingParams
# Note: the next import repeats the earlier `import vllm.envs as envs`.
17
import vllm.envs as envs
18

19
from ..utils import models_path_prefix
20
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET
21

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Force the V0 engine for every test in this module.

    The tests here inspect V0 engine internals (``llm_engine.stat_loggers``,
    ``LLMEngine`` / ``AsyncLLMEngine`` objects), so ``VLLM_USE_V1`` is set to
    ``'0'`` for each test function. The change goes through pytest's
    ``monkeypatch`` fixture, which undoes it after the test.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')

# Model paths exercised by the metric tests, joined onto
# ``models_path_prefix``.
MODELS = [
    os.path.join(models_path_prefix, "distilbert/distilgpt2"),
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "dtype", ["float" if envs.VLLM_USE_TRITON_FLASH_ATTN else "half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_prompt_tokens(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check the Prometheus prompt-token counter against the tokenizer.

    After greedy generation over ``example_prompts``, the engine's
    ``counter_prompt_tokens`` value must equal the sum of the per-prompt
    token counts computed with the model's own tokenizer.
    """
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.4) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        prompt_token_counts = [
            len(tokenizer.encode(p)) for p in example_prompts
        ]
        # This test needs at least 2 prompts in a batch of different lengths to
        # verify their token count is correct despite padding.
        assert len(example_prompts) > 1, "at least 2 prompts are required"
        assert prompt_token_counts[0] != prompt_token_counts[1], (
            "prompts of different lengths are required")
        vllm_prompt_token_count = sum(prompt_token_counts)

        _ = vllm_model.generate_greedy(example_prompts, max_tokens)
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        # Reads the counter through prometheus_client's private ``_value``
        # attribute, scoped to this engine's label set.
        metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
            **stat_logger.labels)._value.get()

    assert vllm_prompt_token_count == metric_count, (
        f"prompt token count: {vllm_prompt_token_count!r}\n"
        f"metric: {metric_count!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "dtype", ["float" if envs.VLLM_USE_TRITON_FLASH_ATTN else "half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_generation_tokens(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Check the Prometheus generation-token counter against the outputs.

    The expected count is the total number of output token ids minus the
    prompt token ids, summed over all prompts. It must equal the engine's
    ``counter_generation_tokens`` value.
    """
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.4) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
        tokenizer = vllm_model.model.get_tokenizer()
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        metric_count = stat_logger.metrics.counter_generation_tokens.labels(
            **stat_logger.labels)._value.get()
        vllm_generation_count = 0
        for prompt, (vllm_output_ids, _) in zip(example_prompts, vllm_outputs):
            prompt_ids = tokenizer.encode(prompt)
            # vllm_output_ids contains both prompt tokens and generation tokens.
            # We're interested only in the count of the generation tokens.
            vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)

    assert vllm_generation_count == metric_count, (
        f"generation token count: {vllm_generation_count!r}\n"
        f"metric: {metric_count!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [128, 129])
@pytest.mark.parametrize("disable_async_output_proc", [True, False])
def test_metric_counter_generation_tokens_multi_step(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    disable_async_output_proc: bool,
) -> None:
    """Approximate generation-token counter check with multi-step scheduling.

    Runs with ``num_scheduler_steps=8``. It only requires the counter to be
    within ``len(example_prompts) * num_scheduler_steps`` of the token count
    derived from the outputs, not an exact match.
    """
    num_scheduler_steps = 8
    with vllm_runner(
            model,
            disable_log_stats=False,
            gpu_memory_utilization=0.4,
            num_scheduler_steps=num_scheduler_steps,
            disable_async_output_proc=disable_async_output_proc,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
        tokenizer = vllm_model.model.get_tokenizer()
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        metric_count = stat_logger.metrics.counter_generation_tokens.labels(
            **stat_logger.labels)._value.get()
        vllm_generation_count = 0
        for prompt, (vllm_output_ids, _) in zip(example_prompts, vllm_outputs):
            prompt_ids = tokenizer.encode(prompt)
            # vllm_output_ids contains both prompt tokens and generation tokens.
            # We're interested only in the count of the generation tokens.
            vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)

    # The multi-step scheduling will continue to execute forward even when
    # encountering EOS, leading to slightly imprecise metrics.
    tolerance = len(example_prompts) * num_scheduler_steps
    assert abs(vllm_generation_count - metric_count) < tolerance, (
        f"generation token count: {vllm_generation_count!r}\n"
        f"metric: {metric_count!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "dtype", ["float" if envs.VLLM_USE_TRITON_FLASH_ATTN else "half"])
@pytest.mark.parametrize(
    "served_model_name",
    [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]])
def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
                                   served_model_name: list[str]) -> None:
    """Check the ``model_name`` label attached to the Prometheus metrics.

    ``served_model_name`` is also parametrized with ``None``. When it is
    ``None`` or empty, the label is expected to be the model path (prefixed
    with the S3 bucket when ``VLLM_CI_USE_S3`` is set). Otherwise the label
    is expected to be the first served name.
    """
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.3,
                     served_model_name=served_model_name) as vllm_model:
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        metrics_tag_content = stat_logger.labels["model_name"]

    if envs.VLLM_CI_USE_S3:
        model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
    # ``None`` and ``[]`` both mean "no served name was given".
    if not served_model_name:
        assert metrics_tag_content == model, (
            f"Metrics tag model_name is wrong! expect: {model!r}\n"
            f"actual: {metrics_tag_content!r}")
    else:
        assert metrics_tag_content == served_model_name[0], (
            f"Metrics tag model_name is wrong! expect: "
            f"{served_model_name[0]!r}\n"
            f"actual: {metrics_tag_content!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("disable_log_stats", [True, False])
@pytest.mark.asyncio
async def test_async_engine_log_metrics_regression(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    disable_log_stats: bool,
) -> None:
    """
    Regression test ensuring async engine generates metrics
    when disable_log_stats=False
    (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678)
    """
    engine_args = AsyncEngineArgs(
        model=model,
        dtype=dtype,
        disable_log_stats=disable_log_stats,
    )
    async_engine = AsyncLLMEngine.from_engine_args(engine_args)
    for i, prompt in enumerate(example_prompts):
        results = async_engine.generate(
            prompt,
            SamplingParams(max_tokens=max_tokens),
            f"request-id-{i}",
        )
        # Exhaust the async iterator to make the async engine work
        async for _ in results:
            pass

    # NOTE(review): test_engine_log_metrics_regression rewrites the expected
    # model_name label with MODEL_WEIGHTS_S3_BUCKET when VLLM_CI_USE_S3 is
    # set; this async variant does not. Confirm which one is intended.
    assert_metrics(model, async_engine.engine, disable_log_stats,
                   len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("disable_log_stats", [True, False])
def test_engine_log_metrics_regression(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    disable_log_stats: bool,
) -> None:
    """Synchronous counterpart of the async metrics regression test.

    Drives an ``LLMEngine`` step by step until every request has finished,
    then hands the result to ``assert_metrics``.
    """
    engine_args = EngineArgs(
        model=model,
        dtype=dtype,
        disable_log_stats=disable_log_stats,
    )
    engine = LLMEngine.from_engine_args(engine_args)
    for i, prompt in enumerate(example_prompts):
        engine.add_request(
            f"request-id-{i}",
            prompt,
            SamplingParams(max_tokens=max_tokens),
        )
    while engine.has_unfinished_requests():
        engine.step()

    # The expected model_name label uses the S3-prefixed path in S3 CI runs,
    # mirroring test_metric_set_tag_model_name.
    if envs.VLLM_CI_USE_S3:
        model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
    assert_metrics(model, engine, disable_log_stats, len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Sanity-check the speculative-decoding Prometheus metrics.

    Runs a single prompt with ``k`` speculative tokens, drafted by the same
    model that is being served. Each spec-decode metric must then fall in a
    loose expected range.
    """
    k = 5

    with vllm_runner(
            model,
            dtype=dtype,
            disable_log_stats=False,
            gpu_memory_utilization=0.4,
            speculative_config={
                "model": model,
                "num_speculative_tokens": k,
            },
    ) as vllm_model:

        # Force log interval to be 0 to catch all metrics.
        stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
        stat_logger.local_interval = 0

        # Note that the purpose of this test is to verify spec decode
        # metrics instead of functional correctness, so the expected values
        # are intended to be loose.
        metric_name_to_expected_fn = {
            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
            "counter_spec_decode_num_emitted_tokens":
            lambda v: 0 <= v <= k + 1,
        }

        # Use one request to better inspect the metrics.
        prompts = example_prompts[:1]

        _ = vllm_model.generate_greedy(prompts, max_tokens)
        for metric_name, is_expected in metric_name_to_expected_fn.items():
            metric_val = getattr(
                stat_logger.metrics,
                metric_name).labels(**stat_logger.labels)._value.get()
            assert is_expected(metric_val), (
                f"the value of metric {metric_name} ({metric_val}) "
                "does not meet expectation")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
def test_metric_spec_decode_interval(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    log_interval: int,
) -> None:
    """Check spec-decode metrics when logging runs on a non-zero interval.

    Drives a single request manually with ``engine.step()`` and uses sleeps
    so that the metrics are collected and logged. The spec-decode metrics are
    then checked against loose expected ranges. The engine and distributed
    state are always torn down in ``finally``.
    """
    k = 5

    engine_args = EngineArgs(
        model=model,
        dtype=dtype,
        disable_log_stats=False,
        gpu_memory_utilization=0.4,
        speculative_config={
            "model": model,
            "num_speculative_tokens": k,
        },
        enforce_eager=True,
    )

    engine = LLMEngine.from_engine_args(engine_args)

    try:

        engine.add_request(
            "request-id-0",
            example_prompts[0],
            SamplingParams(max_tokens=max_tokens),
        )

        # set log interval
        stat_logger = engine.stat_loggers['prometheus']
        stat_logger.local_interval = log_interval

        # prefill
        engine.step()

        # wait for 5 seconds to ensure that spec decode metrics
        # get triggered in first decode step
        time.sleep(5)

        # first decode step should trigger async collection of metrics
        engine.step()

        # wait one second to allow H2D transfer to finish
        time.sleep(1)

        # second decode step should now be able to collect the spec
        # decode stats and the request should also be finished
        engine.step()

        # must have finished now
        assert not engine.has_unfinished_requests()

        # wait to ensure logging occurs
        time.sleep(log_interval)

        # force logging
        engine.step()

        # Note that the purpose of this test is to verify spec decode
        # metrics instead of functional correctness, so the expected values
        # are intended to be loose.
        metric_name_to_expected_fn = {
            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
            "counter_spec_decode_num_emitted_tokens":
            lambda v: 0 <= v <= k + 1,
        }

        for metric_name, is_expected in metric_name_to_expected_fn.items():
            metric_val = getattr(
                stat_logger.metrics,
                metric_name).labels(**stat_logger.labels)._value.get()
            assert is_expected(metric_val), (
                f"the value of metric {metric_name} ({metric_val}) "
                "does not meet expectation")

    finally:
        del engine
        cleanup_dist_env_and_memory()


def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool,
                   num_requests: int) -> None:
    """Assert that ``engine`` emitted request-level metrics, or none at all.

    Args:
        model: Expected value of the ``model_name`` label on the metrics.
        engine: The engine that served the requests.
        disable_log_stats: When true, ``engine.stat_loggers`` must not exist
            (accessing it must raise ``AttributeError``).
        num_requests: Number of finished requests. Each request-level
            histogram's ``_count`` sample in the global Prometheus
            ``REGISTRY`` must equal it.
    """
    if disable_log_stats:
        with pytest.raises(AttributeError):
            _ = engine.stat_loggers
    else:
        assert (engine.stat_loggers
                is not None), "engine.stat_loggers should be set"

        # Ensure the count bucket of request-level histogram metrics matches
        # the number of requests as a simple sanity check to ensure metrics are
        # generated
        labels = {'model_name': model}
        request_histogram_metrics = [
            "vllm:e2e_request_latency_seconds",
            "vllm:request_prompt_tokens",
            "vllm:request_generation_tokens",
            "vllm:request_params_n",
            "vllm:request_params_max_tokens",
        ]
        for metric_name in request_histogram_metrics:
            metric_value = REGISTRY.get_sample_value(f"{metric_name}_count",
                                                     labels)
            # get_sample_value returns None when no sample matches the
            # labels, e.g. when the model_name label differs.
            assert metric_value == num_requests, (
                f"Metrics should be collected ({metric_name}: expected "
                f"{num_requests}, got {metric_value!r})")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Smoke-test ``RayPrometheusStatLogger`` inside a Ray task.

    A subclass counts calls to ``log``. The test requires at least one call
    after all requests have been driven to completion.
    """
    # This test is quite weak - it only checks that we can use
    # RayPrometheusStatLogger without exceptions.
    # Checking whether the metrics are actually emitted is unfortunately
    # non-trivial.

    # We have to run in a Ray task for Ray metrics to be emitted correctly
    @ray.remote(num_gpus=1)
    def _inner():

        class _RayPrometheusStatLogger(RayPrometheusStatLogger):
            # Counts invocations of ``log`` for the assertion below.

            def __init__(self, *args, **kwargs):
                self._i = 0
                super().__init__(*args, **kwargs)

            def log(self, *args, **kwargs):
                self._i += 1
                return super().log(*args, **kwargs)

        engine_args = EngineArgs(
            model=model,
            dtype=dtype,
            disable_log_stats=False,
        )
        engine = LLMEngine.from_engine_args(engine_args)
        logger = _RayPrometheusStatLogger(
            local_interval=0.5,
            labels=dict(model_name=engine.model_config.served_model_name),
            vllm_config=engine.vllm_config)
        engine.add_logger("ray", logger)
        for i, prompt in enumerate(example_prompts):
            engine.add_request(
                f"request-id-{i}",
                prompt,
                SamplingParams(max_tokens=max_tokens),
            )
        while engine.has_unfinished_requests():
            engine.step()
        assert logger._i > 0, ".log must be called at least once"

    ray.get(_inner.remote())