Unverified commit 0ccecf88, authored by 7. Sun and committed by GitHub
Browse files

[Tests] Standardize RNG seed utility across test files (#32982)


Signed-off-by: 7. Sun <jhao.sun@gmail.com>
parent 0b9a735e
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend""" """Integration tests for FlexAttention backend vs default backend"""
import random
import numpy as np
import pytest import pytest
import torch import torch
from packaging import version from packaging import version
from tests.utils import set_random_seed
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
BatchSpec, BatchSpec,
create_common_attn_metadata, create_common_attn_metadata,
...@@ -27,15 +25,6 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0") ...@@ -27,15 +25,6 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0") DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
def set_seed(seed):
    """Set seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    default generator with the same value. When CUDA is available, the
    seed is also applied to every CUDA device.

    Args:
        seed: Seed value forwarded unchanged to each RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA RNGs when a CUDA runtime is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
@pytest.mark.skipif( @pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
...@@ -57,7 +46,7 @@ def test_flex_attention_vs_default_backend(vllm_runner): ...@@ -57,7 +46,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
] ]
# Run with flex attention # Run with flex attention
set_seed(seed) set_random_seed(seed)
with vllm_runner( with vllm_runner(
model_name, model_name,
runner="generate", runner="generate",
...@@ -71,7 +60,7 @@ def test_flex_attention_vs_default_backend(vllm_runner): ...@@ -71,7 +60,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
) )
# Run with default backend # Run with default backend
set_seed(seed) set_random_seed(seed)
with vllm_runner( with vllm_runner(
model_name, model_name,
runner="generate", runner="generate",
......
...@@ -59,7 +59,10 @@ from vllm.tokenizers import get_tokenizer ...@@ -59,7 +59,10 @@ from vllm.tokenizers import get_tokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GB_bytes from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import (
cuda_device_count_stateless,
set_random_seed, # noqa: F401 - re-exported for use in test files
)
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any from typing import Any
import pytest import pytest
from tests.utils import create_new_process_for_each_test from tests.utils import create_new_process_for_each_test, set_random_seed
from tests.v1.logits_processors.utils import ( from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN, DUMMY_LOGITPROC_FQCN,
...@@ -135,7 +134,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource ...@@ -135,7 +134,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
# Test that logitproc info is passed to workers # Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
random.seed(40) set_random_seed(40)
# Choose LLM args based on logitproc source # Choose LLM args based on logitproc source
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE: if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE:
...@@ -194,7 +193,7 @@ def test_custom_logitsprocs_req(monkeypatch): ...@@ -194,7 +193,7 @@ def test_custom_logitsprocs_req(monkeypatch):
# Test that logitproc info is passed to workers # Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
random.seed(40) set_random_seed(40)
_run_test( _run_test(
{"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True {"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True
) )
...@@ -237,7 +236,7 @@ def test_rejects_custom_logitsprocs( ...@@ -237,7 +236,7 @@ def test_rejects_custom_logitsprocs(
logitproc from logitproc from
""" """
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
random.seed(40) set_random_seed(40)
test_params: dict[str, dict[str, Any]] = { test_params: dict[str, dict[str, Any]] = {
"pooling": { "pooling": {
......
...@@ -333,6 +333,8 @@ def set_random_seed(seed: int | None) -> None: ...@@ -333,6 +333,8 @@ def set_random_seed(seed: int | None) -> None:
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def create_kv_caches_with_random_flash( def create_kv_caches_with_random_flash(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment