Unverified Commit ec68d53b authored by Yan Ma's avatar Yan Ma Committed by GitHub
Browse files

Add platform manual_seed_all API (#38468)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
parent 13e6b1b9
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size, moe_align_block_size,
) )
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.torch_utils import set_random_seed
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
...@@ -44,7 +45,7 @@ configs = list( ...@@ -44,7 +45,7 @@ configs = list(
def benchmark(num_tokens, num_experts, topk, ep_size, provider): def benchmark(num_tokens, num_experts, topk, ep_size, provider):
"""Benchmark function for Triton.""" """Benchmark function for Triton."""
block_size = 256 block_size = 256
torch.cuda.manual_seed_all(0) set_random_seed(0)
topk_ids = get_topk_ids(num_tokens, num_experts, topk) topk_ids = get_topk_ids(num_tokens, num_experts, topk)
e_map = None e_map = None
......
...@@ -16,6 +16,7 @@ from vllm.utils.deep_gemm import ( ...@@ -16,6 +16,7 @@ from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
per_block_cast_to_fp8, per_block_cast_to_fp8,
) )
from vllm.utils.torch_utils import set_random_seed
def benchmark_shape( def benchmark_shape(
...@@ -235,9 +236,7 @@ def run_benchmarks(verbose: bool = False): ...@@ -235,9 +236,7 @@ def run_benchmarks(verbose: bool = False):
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
# Set seeds for reproducibility # Set seeds for reproducibility
torch.manual_seed(42) set_random_seed(42)
torch.cuda.manual_seed(42)
# Define benchmark shapes (m, n, k) # Define benchmark shapes (m, n, k)
shapes = [ shapes = [
(8, 4096, 7168), (8, 4096, 7168),
......
...@@ -122,8 +122,6 @@ def test_linear_decode_forward_triton( ...@@ -122,8 +122,6 @@ def test_linear_decode_forward_triton(
dtype: torch.dtype, dtype: torch.dtype,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
set_random_seed(42) set_random_seed(42)
base = 0.01 base = 0.01
q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
...@@ -165,8 +163,6 @@ def test_linear_decode_forward_triton_with_padding( ...@@ -165,8 +163,6 @@ def test_linear_decode_forward_triton_with_padding(
dtype: torch.dtype, dtype: torch.dtype,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
set_random_seed(42) set_random_seed(42)
batch_size = 4 batch_size = 4
...@@ -229,8 +225,6 @@ def test_lightning_attention_reference( ...@@ -229,8 +225,6 @@ def test_lightning_attention_reference(
dtype: torch.dtype, dtype: torch.dtype,
): ):
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
set_random_seed(42) set_random_seed(42)
base = 0.01 base = 0.01
......
...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()] QUANT_DTYPES = [torch.int8, current_platform.fp8_dtype()]
...@@ -180,9 +181,7 @@ def test_rms_norm( ...@@ -180,9 +181,7 @@ def test_rms_norm(
device: str, device: str,
strided_input: bool, strided_input: bool,
) -> None: ) -> None:
torch.random.manual_seed(seed) set_random_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
torch.accelerator.set_device_index(device) torch.accelerator.set_device_index(device)
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.torch_utils import set_random_seed
def mxint4_quantize( def mxint4_quantize(
...@@ -134,7 +135,7 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group ...@@ -134,7 +135,7 @@ def test_marlin_vs_trtllm_mxint4_moe_kimik2(monkeypatch, m, n, k, e, topk, group
pytest.importorskip("flashinfer") pytest.importorskip("flashinfer")
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_INT4", "1") monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_INT4", "1")
torch.cuda.manual_seed(0) set_random_seed(0)
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -289,7 +290,7 @@ def test_flashinfer_trtllm_mxint4_moe_wrapper(m, n, k, e, topk): ...@@ -289,7 +290,7 @@ def test_flashinfer_trtllm_mxint4_moe_wrapper(m, n, k, e, topk):
flashinfer_trtllm_mxint4_moe, flashinfer_trtllm_mxint4_moe,
) )
torch.cuda.manual_seed(0) set_random_seed(0)
dtype = torch.bfloat16 dtype = torch.bfloat16
a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5 a = torch.randn((m, k), device="cuda", dtype=dtype) * 0.5
......
...@@ -1031,7 +1031,7 @@ def test_fused_marlin_moe( ...@@ -1031,7 +1031,7 @@ def test_fused_marlin_moe(
act_order: bool, act_order: bool,
is_k_full: bool, is_k_full: bool,
): ):
torch.cuda.manual_seed(1) set_random_seed(1)
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if c_type == scalar_types.float16: if c_type == scalar_types.float16:
...@@ -1131,7 +1131,7 @@ def test_fused_marlin_moe( ...@@ -1131,7 +1131,7 @@ def test_fused_marlin_moe(
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256]) @pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m): def test_fused_marlin_moe_with_bias(m):
torch.cuda.manual_seed(0) set_random_seed(0)
e, topk = 32, 4 e, topk = 32, 4
n, k = 2048, 2048 n, k = 2048, 2048
...@@ -1213,7 +1213,7 @@ def test_fused_marlin_moe_non_gated( ...@@ -1213,7 +1213,7 @@ def test_fused_marlin_moe_non_gated(
Non-gated activations like relu2 don't have the gate-up projection pattern, Non-gated activations like relu2 don't have the gate-up projection pattern,
so w1 has shape (e, n, k) instead of (e, 2*n, k). so w1 has shape (e, n, k) instead of (e, 2*n, k).
""" """
torch.cuda.manual_seed(42) set_random_seed(42)
group_size = 16 # NVFP4 group size group_size = 16 # NVFP4 group size
is_k_full = True is_k_full = True
...@@ -1397,7 +1397,7 @@ def test_cpu_fused_moe_basic( ...@@ -1397,7 +1397,7 @@ def test_cpu_fused_moe_basic(
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
device = "cpu" device = "cpu"
torch.manual_seed(7) set_random_seed(7)
a = torch.randn((m, k), device=device, dtype=dtype) / 10 a = torch.randn((m, k), device=device, dtype=dtype) / 10
w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 w13 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
...@@ -1469,7 +1469,7 @@ def test_batched_fused_marlin_moe( ...@@ -1469,7 +1469,7 @@ def test_batched_fused_marlin_moe(
f"topk={topk}, " f"topk={topk}, "
f"max_tokens_per_batch={max_tokens_per_batch}" f"max_tokens_per_batch={max_tokens_per_batch}"
) )
torch.cuda.manual_seed(0) set_random_seed(0)
dtype = torch.bfloat16 dtype = torch.bfloat16
quant_dtype = scalar_types.float4_e2m1f quant_dtype = scalar_types.float4_e2m1f
......
...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer, set_random_seed
class SimpleLinear(nn.Module): class SimpleLinear(nn.Module):
...@@ -144,8 +144,7 @@ def test_routed_input_transform_inside_vs_outside( ...@@ -144,8 +144,7 @@ def test_routed_input_transform_inside_vs_outside(
rocm_aiter_ops.refresh_env_variables() rocm_aiter_ops.refresh_env_variables()
torch.manual_seed(42) set_random_seed(42)
torch.cuda.manual_seed(42)
num_experts = 8 num_experts = 8
top_k = 2 top_k = 2
......
...@@ -7,6 +7,7 @@ import vllm._custom_ops as ops ...@@ -7,6 +7,7 @@ import vllm._custom_ops as ops
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float16] DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [current_platform.fp8_dtype()] QUANT_DTYPES = [current_platform.fp8_dtype()]
...@@ -49,9 +50,7 @@ def test_silu_and_mul( ...@@ -49,9 +50,7 @@ def test_silu_and_mul(
seed: int, seed: int,
device: str, device: str,
) -> None: ) -> None:
torch.random.manual_seed(seed) set_random_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
layer = SiluAndMul() layer = SiluAndMul()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import numpy as np
import pytest import pytest
import torch import torch
from transformers import AutoModelForTokenClassification from transformers import AutoModelForTokenClassification
from tests.models.utils import softmax from tests.models.utils import softmax
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def seed_everything(): def seed_everything():
"""Seed all random number generators for reproducibility.""" """Seed all random number generators for reproducibility."""
seed = 0 seed = 0
random.seed(seed) set_random_seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
yield yield
......
...@@ -5,9 +5,6 @@ This script contains: ...@@ -5,9 +5,6 @@ This script contains:
1. test lora with speculative decoding for batch inference 1. test lora with speculative decoding for batch inference
""" """
import random
import numpy as np
import pytest import pytest
import torch import torch
...@@ -15,6 +12,7 @@ from vllm import LLM, SamplingParams ...@@ -15,6 +12,7 @@ from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
LORA_TEST_PROMPT_MAP: dict[str, str] = {} LORA_TEST_PROMPT_MAP: dict[str, str] = {}
...@@ -63,10 +61,7 @@ def test_batch_inference_correctness( ...@@ -63,10 +61,7 @@ def test_batch_inference_correctness(
with monkeypatch.context() as m: with monkeypatch.context() as m:
# Disable randomness # Disable randomness
m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8") m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.manual_seed(SEED) set_random_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
......
...@@ -14,6 +14,7 @@ from tests.v1.attention.utils import ( ...@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
) )
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
...@@ -323,8 +324,7 @@ def forward_attention( ...@@ -323,8 +324,7 @@ def forward_attention(
def test_tree_attn_correctness( def test_tree_attn_correctness(
reference_backend: AttentionBackendEnum, reference_backend: AttentionBackendEnum,
) -> None: ) -> None:
torch.manual_seed(42) set_random_seed(42)
torch.cuda.manual_seed_all(42)
device = "cuda" device = "cuda"
tree_attn_masks = { tree_attn_masks = {
......
...@@ -9,6 +9,7 @@ import regex as re ...@@ -9,6 +9,7 @@ import regex as re
# --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- #
_TORCH_CUDA_PATTERNS = [ _TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b", r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b",
r"\btorch\.cuda\.(manual_seed|manual_seed_all)\b",
r"\bwith\storch\.cuda\.device\b", r"\bwith\storch\.cuda\.device\b",
# Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml} internally # Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml} internally
r"\bcuda_device_count_stateless\(\)\b", r"\bcuda_device_count_stateless\(\)\b",
...@@ -24,6 +25,14 @@ def scan_file(path: str) -> int: ...@@ -24,6 +25,14 @@ def scan_file(path: str) -> int:
for match in re.finditer(pattern, content, re.MULTILINE): for match in re.finditer(pattern, content, re.MULTILINE):
# Calculate line number from match position # Calculate line number from match position
line_num = content[: match.start() + 1].count("\n") + 1 line_num = content[: match.start() + 1].count("\n") + 1
matched_text = match.group(0)
if "manual_seed" in matched_text:
print(
f"{path}:{line_num}: "
"\033[91merror:\033[0m "
f"Found {matched_text} API call. Use set_random_seed instead."
)
return 1
print( print(
f"{path}:{line_num}: " f"{path}:{line_num}: "
"\033[91merror:\033[0m " # red color "\033[91merror:\033[0m " # red color
......
...@@ -154,6 +154,10 @@ class CpuPlatform(Platform): ...@@ -154,6 +154,10 @@ class CpuPlatform(Platform):
""" """
torch.cpu.set_device(device) torch.cpu.set_device(device)
@classmethod
def manual_seed_all(cls, seed: int) -> None:
pass
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
......
...@@ -188,6 +188,10 @@ class CudaPlatformBase(Platform): ...@@ -188,6 +188,10 @@ class CudaPlatformBase(Platform):
# for why and when it is needed # for why and when it is needed
_ = torch.zeros(1, device=device) _ = torch.zeros(1, device=device)
@classmethod
def manual_seed_all(cls, seed: int) -> None:
torch.cuda.manual_seed_all(seed)
@classmethod @classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
raise NotImplementedError raise NotImplementedError
......
...@@ -391,6 +391,11 @@ class Platform: ...@@ -391,6 +391,11 @@ class Platform:
""" """
raise NotImplementedError raise NotImplementedError
@classmethod
def manual_seed_all(cls, seed: int) -> None:
"""Set RNG seed across all devices for the current platform."""
raise NotImplementedError
@classmethod @classmethod
def pre_register_and_update( def pre_register_and_update(
cls, parser: FlexibleArgumentParser | None = None cls, parser: FlexibleArgumentParser | None = None
......
...@@ -605,6 +605,10 @@ class RocmPlatform(Platform): ...@@ -605,6 +605,10 @@ class RocmPlatform(Platform):
""" """
torch.cuda.set_device(device) torch.cuda.set_device(device)
@classmethod
def manual_seed_all(cls, seed: int) -> None:
torch.cuda.manual_seed_all(seed)
@classmethod @classmethod
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
......
...@@ -125,6 +125,10 @@ class XPUPlatform(Platform): ...@@ -125,6 +125,10 @@ class XPUPlatform(Platform):
""" """
torch.xpu.set_device(device) torch.xpu.set_device(device)
@classmethod
def manual_seed_all(cls, seed: int) -> None:
torch.xpu.manual_seed_all(seed)
@classmethod @classmethod
def get_device_capability( def get_device_capability(
cls, cls,
......
...@@ -365,8 +365,9 @@ def set_random_seed(seed: int | None) -> None: ...@@ -365,8 +365,9 @@ def set_random_seed(seed: int | None) -> None:
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): from vllm.platforms import current_platform
torch.cuda.manual_seed_all(seed)
current_platform.manual_seed_all(seed)
def create_kv_caches_with_random_flash( def create_kv_caches_with_random_flash(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment