"vscode:/vscode.git/clone" did not exist on "ecd49ce7e69a50892be7f9841941ca2d7e3b12ea"
Unverified commit 622b7ab9, authored by wangshuai09, committed by GitHub
Browse files

[Hardware] using current_platform.seed_everything (#9785)


Signed-off-by: wangshuai09 <391746016@qq.com>
parent 09500f7d
...@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask ...@@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
...@@ -39,7 +40,7 @@ def test_contexted_kv_attention( ...@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
seed_everything(0) current_platform.seed_everything(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
...@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi( ...@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
) -> None: ) -> None:
seed_everything(0) current_platform.seed_everything(0)
torch.set_default_device(device) torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # Need this, otherwise when we capture the graph the process
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything from vllm.platforms import current_platform
from .utils import DummyLoRAManager from .utils import DummyLoRAManager
...@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, ...@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None: seq_len) -> None:
dtype = torch.float16 dtype = torch.float16
seed = 0 seed = 0
seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device) punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8 max_loras = 8
......
""" """
This script is mainly used to test various hidden_sizes. We have collected the This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64]. is set to [1, 2, 4, 8, 16, 32, 64].
...@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink ...@@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -146,7 +146,7 @@ def test_punica_sgmv( ...@@ -146,7 +146,7 @@ def test_punica_sgmv(
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 seq_length = 128
( (
...@@ -239,7 +239,7 @@ def test_punica_bgmv( ...@@ -239,7 +239,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 1 seq_length = 1
( (
...@@ -327,7 +327,7 @@ def test_punica_expand_nslices( ...@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
......
""" """
This script is mainly used to test whether Triton kernels can run normally This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batches, numbers of LoRAs, and under different conditions, including various batches, numbers of LoRAs, and
maximum ranks. maximum ranks.
""" """
from unittest.mock import patch from unittest.mock import patch
...@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink ...@@ -14,8 +14,8 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry from vllm.triton_utils.libentry import LibEntry
from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices, from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm) ref_torch_groupgemm)
...@@ -61,7 +61,7 @@ def test_punica_sgmv( ...@@ -61,7 +61,7 @@ def test_punica_sgmv(
device: str, device: str,
): ):
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 seq_length = 128
( (
...@@ -154,7 +154,7 @@ def test_punica_bgmv( ...@@ -154,7 +154,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 1 seq_length = 1
( (
...@@ -242,7 +242,7 @@ def test_punica_expand_nslices( ...@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device) torch.set_default_device(device)
seed_everything(seed) current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1 seq_length = 128 if op_type == "sgmv" else 1
( (
......
...@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional ...@@ -4,11 +4,10 @@ from typing import Any, Dict, Optional
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None: def set_random_seed(seed: int) -> None:
seed_everything(seed) current_platform.seed_everything(seed)
def set_weight_attrs( def set_weight_attrs(
......
import enum import enum
import random
from typing import NamedTuple, Optional, Tuple, Union from typing import NamedTuple, Optional, Tuple, Union
import numpy as np
import torch import torch
...@@ -111,6 +113,18 @@ class Platform: ...@@ -111,6 +113,18 @@ class Platform:
""" """
return torch.inference_mode(mode=True) return torch.inference_mode(mode=True)
@classmethod
def seed_everything(cls, seed: int) -> None:
    """
    Set the seed of each random module.

    The same value is passed to Python's `random`, to NumPy
    (`np.random.seed`) and to PyTorch (`torch.manual_seed`).
    `torch.manual_seed` will set seed on all devices.

    Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20

    Args:
        seed: Seed value forwarded unchanged to every seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
...@@ -7,7 +7,6 @@ import gc ...@@ -7,7 +7,6 @@ import gc
import inspect import inspect
import ipaddress import ipaddress
import os import os
import random
import socket import socket
import subprocess import subprocess
import sys import sys
...@@ -331,22 +330,6 @@ def get_cpu_memory() -> int: ...@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
    """
    Set the seed of each random module.

    Always seeds Python's `random` and NumPy (`np.random.seed`). Device
    generators are seeded only for the platform reported by
    `current_platform`: `torch.cuda.manual_seed_all` when it is CUDA-alike,
    `torch.xpu.manual_seed_all` when it is XPU.

    NOTE(review): this function never calls `torch.manual_seed`, so on any
    other platform no torch generator is seeded here.

    Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20

    Args:
        seed: Seed value forwarded unchanged to every seeding call.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Two independent checks rather than if/elif: each platform predicate
    # is evaluated on its own.
    if current_platform.is_cuda_alike():
        torch.cuda.manual_seed_all(seed)
    if current_platform.is_xpu():
        torch.xpu.manual_seed_all(seed)
def random_uuid() -> str: def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
...@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash( ...@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
...@@ -685,7 +668,7 @@ def create_kv_caches_with_random( ...@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}" f"Does not support key cache of type fp8 with head_size {head_size}"
) )
seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment