Commit 0dc55ec0 authored by zhuwenwen
Browse files

add large_gpu_test

parent 47be5a1c
...@@ -23,7 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -23,7 +23,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless, from vllm.utils import (FlexibleArgumentParser, GB_bytes,cuda_device_count_stateless,
get_open_port, is_hip) get_open_port, is_hip)
import vllm.envs as envs import vllm.envs as envs
import os import os
...@@ -459,6 +459,36 @@ def fork_new_process_for_each_test( ...@@ -459,6 +459,36 @@ def fork_new_process_for_each_test(
return wrapper return wrapper
def large_gpu_test(*, min_gb: int):
    """
    Decorate a test to be skipped if no GPU is available or it does not have
    sufficient memory.

    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.

    Args:
        min_gb: Minimum total device memory the test needs, expressed in
            units of ``GB_bytes``.

    Returns:
        A decorator that wraps the test with
        ``fork_new_process_for_each_test`` and applies a ``skipif`` mark.
    """
    # The memory probe runs once, when the decorator is created, not each
    # time the decorated test is invoked.
    try:
        if current_platform.is_cpu():
            # CPU-only runs are treated as having no GPU memory at all.
            memory_gb = 0
        else:
            memory_gb = current_platform.get_device_total_memory() / GB_bytes
    except Exception as e:
        # Deliberate best effort: if the platform cannot report its memory,
        # warn and fall back to 0 so the test is skipped rather than crashing
        # the module that applies this decorator.
        warnings.warn(
            f"An error occurred when finding the available memory: {e}",
            stacklevel=2,
        )
        memory_gb = 0

    test_skipif = pytest.mark.skipif(
        memory_gb < min_gb,
        # Report the requirement (min_gb), not the amount that was detected.
        reason=f"Need at least {min_gb}GB GPU memory to run the test.",
    )

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        return test_skipif(fork_new_process_for_each_test(f))

    return wrapper
def multi_gpu_test(*, num_gpus: int): def multi_gpu_test(*, num_gpus: int):
""" """
......
import psutil
import torch import torch
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
...@@ -9,6 +10,10 @@ class CpuPlatform(Platform): ...@@ -9,6 +10,10 @@ class CpuPlatform(Platform):
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return "cpu" return "cpu"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total`` field of ``psutil.virtual_memory()``.

    ``device_id`` is accepted to match the signature but is not used here.
    """
    host_memory = psutil.virtual_memory()
    return host_memory.total
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
......
...@@ -83,6 +83,11 @@ class Platform: ...@@ -83,6 +83,11 @@ class Platform:
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Get the total memory of a device in bytes.

    This implementation is a stub and provides no value.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`. """A device-specific wrapper of `torch.inference_mode`.
......
...@@ -29,3 +29,8 @@ class RocmPlatform(Platform): ...@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
@lru_cache(maxsize=8) @lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id) return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total_memory`` that torch reports for ``device_id``.

    The value is read from ``torch.cuda.get_device_properties``.
    """
    return torch.cuda.get_device_properties(device_id).total_memory
...@@ -9,6 +9,10 @@ class TpuPlatform(Platform): ...@@ -9,6 +9,10 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def get_device_name(cls, device_id: int = 0) -> str: def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Stub for querying a device's total memory.

    Raises:
        NotImplementedError: Always; no implementation is provided here.
    """
    raise NotImplementedError
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
......
...@@ -119,6 +119,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS" ...@@ -119,6 +119,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
# Decimal gigabyte: 10 to the 9th power bytes.
GB_bytes = 10**9
"""The number of bytes in one gigabyte (GB)."""
GiB_bytes = 1 << 30 GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB).""" """The number of bytes in one gibibyte (GiB)."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment