Unverified Commit 622b7ab9 authored by wangshuai09's avatar wangshuai09 Committed by GitHub
Browse files

[Hardware] using current_platform.seed_everything (#9785)


Signed-off-by: default avatarwangshuai09 <391746016@qq.com>
parent 09500f7d
......@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
......@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:
seed_everything(0)
current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
......@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:
seed_everything(0)
current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
......
......@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything
from vllm.platforms import current_platform
from .utils import DummyLoRAManager
......@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None:
dtype = torch.float16
seed = 0
seed_everything(seed)
current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
......
"""
This script is mainly used to tests various hidden_sizes. We have collected the
This script is mainly used to tests various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
......@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
......@@ -146,7 +146,7 @@ def test_punica_sgmv(
device: str,
):
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 128
(
......@@ -239,7 +239,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 1
(
......@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
......
"""
This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
This script is mainly used to test whether trtion kernels can run normally
under different conditions, including various batches, numbers of LoRA , and
maximum ranks.
"""
from unittest.mock import patch
......@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
......@@ -61,7 +61,7 @@ def test_punica_sgmv(
device: str,
):
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 128
(
......@@ -154,7 +154,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 1
(
......@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
seed_everything(seed)
current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
......
......@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
import torch
from vllm.platforms import current_platform
from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None:
seed_everything(seed)
current_platform.seed_everything(seed)
def set_weight_attrs(
......
import enum
import random
from typing import NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
......@@ -111,6 +113,18 @@ class Platform:
"""
return torch.inference_mode(mode=True)
@classmethod
def seed_everything(cls, seed: int) -> None:
"""
Set the seed of each random module.
`torch.manual_seed` will set seed on all devices.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......@@ -7,7 +7,6 @@ import gc
import inspect
import ipaddress
import os
import random
import socket
import subprocess
import sys
......@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if current_platform.is_xpu():
torch.xpu.manual_seed_all(seed)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
......@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
seed_everything(seed)
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
......@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
seed_everything(seed)
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment