Unverified Commit 622b7ab9 authored by wangshuai09's avatar wangshuai09 Committed by GitHub
Browse files

[Hardware] using current_platform.seed_everything (#9785)


Signed-off-by: default avatarwangshuai09 <391746016@qq.com>
parent 09500f7d
...@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask ...@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
...@@ -39,7 +40,7 @@ def test_contexted_kv_attention( ...@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
seed_everything(0) current_platform.seed_everything(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
...@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi( ...@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
seed_everything(0) current_platform.seed_everything(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything from vllm.platforms import current_platform
from .utils import DummyLoRAManager from .utils import DummyLoRAManager
...@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None: seq_len) -> None:
dtype = torch.float16 dtype = torch.float16
seed = 0 seed = 0
seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
......
...@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink ...@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -146,7 +146,7 @@ def test_punica_sgmv( ...@@ -146,7 +146,7 @@ def test_punica_sgmv(
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 seq_length = 128
( (
...@@ -239,7 +239,7 @@ def test_punica_bgmv( ...@@ -239,7 +239,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 1 seq_length = 1
( (
...@@ -327,7 +327,7 @@ def test_punica_expand_nslices( ...@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
......
...@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink ...@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -61,7 +61,7 @@ def test_punica_sgmv( ...@@ -61,7 +61,7 @@ def test_punica_sgmv(
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 seq_length = 128
( (
...@@ -154,7 +154,7 @@ def test_punica_bgmv( ...@@ -154,7 +154,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 1 seq_length = 1
( (
...@@ -242,7 +242,7 @@ def test_punica_expand_nslices( ...@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
......
...@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional ...@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
seed_everything(seed) current_platform.seed_everything(seed)
def set_weight_attrs( def set_weight_attrs(
......
import enum import enum
import random
from typing import NamedTuple, Optional, Tuple, Union from typing import NamedTuple, Optional, Tuple, Union
import numpy as np
import torch import torch
...@@ -111,6 +113,18 @@ class Platform: ...@@ -111,6 +113,18 @@ class Platform:
""" """
return torch.inference_mode(mode=True) return torch.inference_mode(mode=True)
@classmethod
def seed_everything(cls, seed: int) -> None:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
...@@ -7,7 +7,6 @@ import gc ...@@ -7,7 +7,6 @@ import gc
import inspect import inspect
import ipaddress import ipaddress
import os import os
import random
import socket import socket
import subprocess import subprocess
import sys import sys
...@@ -331,22 +330,6 @@ def get_cpu_memory() -> int: ...@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if current_platform.is_xpu():
torch.xpu.manual_seed_all(seed)
def random_uuid() -> str: def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
...@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash( ...@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
...@@ -685,7 +668,7 @@ def create_kv_caches_with_random( ...@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}" f"Does not support key cache of type fp8 with head_size {head_size}"
) )
seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment