Unverified Commit a995a773 authored by JieXin Liang, committed by GitHub

[fix] remove `cuda_device_count_stateless` (#5060)

parent 31035dda
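
The change is mechanical: every call site now asks `torch.cuda.device_count()` directly instead of going through the cached `cuda_device_count_stateless` helper, which is deleted in the last hunk. As a rough illustration (not part of the diff; the helper name `visible_device_ids` is made up), the resulting pattern for mapping logical device indices to physical device ids looks like this:

```python
import os

import torch


def visible_device_ids() -> list:
    """Return the physical device ids visible to this process."""
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_visible_devices:
        # Explicit mapping, e.g. CUDA_VISIBLE_DEVICES="2,3" -> [2, 3]
        return list(map(int, cuda_visible_devices.split(",")))
    # Otherwise every device reported by torch is visible.
    return list(range(torch.cuda.device_count()))


if __name__ == "__main__":
    print(visible_device_ids())
```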
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip

 logger = logging.getLogger(__name__)
@@ -217,7 +217,7 @@ class CustomAllreduce:
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
         else:
-            device_ids = list(range(cuda_device_count_stateless()))
+            device_ids = list(range(torch.cuda.device_count()))
         physical_device_id = device_ids[device.index]
         tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
...
@@ -11,11 +11,11 @@ import tempfile
 from itertools import product
 from typing import Dict, List, Optional, Sequence

+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp

 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
-from sglang.srt.utils import cuda_device_count_stateless

 logger = logging.getLogger(__name__)
@@ -218,7 +218,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
     is_distributed = dist.is_initialized()

-    num_dev = cuda_device_count_stateless()
+    num_dev = torch.cuda.device_count()
     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
     if cuda_visible_devices is None:
         cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
...
@@ -263,7 +263,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus = cuda_device_count_stateless()
+        num_gpus = torch.cuda.device_count()
         assert gpu_id < num_gpus

         if torch.cuda.current_device() != gpu_id:
@@ -1416,47 +1416,6 @@ def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")


-@lru_cache(maxsize=8)
-def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
-    # Note: cuda_visible_devices is not used, but we keep it as an argument for
-    # LRU Cache purposes.
-
-    # Code below is based on
-    # https://github.com/pytorch/pytorch/blob/
-    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
-    # torch/cuda/__init__.py#L831C1-L831C17
-    import torch.version
-
-    if not torch.cuda._is_compiled():
-        return 0
-    if is_hip():
-        # ROCm uses amdsmi instead of nvml for stateless device count
-        # This requires a sufficiently modern version of Torch 2.4.0
-        raw_count = (
-            torch.cuda._device_count_amdsmi()
-            if (hasattr(torch.cuda, "_device_count_amdsmi"))
-            else -1
-        )
-    else:
-        raw_count = torch.cuda._device_count_nvml()
-    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
-    return r
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
-def cuda_device_count_stateless() -> int:
-    """Get number of CUDA devices, caching based on the value of
-    CUDA_VISIBLE_DEVICES at the time of call.
-
-    This should be used instead of torch.cuda.device_count()
-    unless CUDA_VISIBLE_DEVICES has already been set to the desired
-    value."""
-
-    # This can be removed and simply replaced with torch.cuda.get_device_count
-    # after https://github.com/pytorch/pytorch/pull/122815 is released.
-    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


 def dataclass_to_string_truncated(
     data, max_length=2048, skip_names: Optional[Set[str]] = None
 ):
...
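
For context, the helper removed in the last hunk wrapped the device count in an `lru_cache` keyed on the current value of `CUDA_VISIBLE_DEVICES`, so that changing the variable after import still returned a fresh count instead of a stale cached one. A minimal sketch of that caching pattern, with hypothetical names and `torch.cuda.device_count()` standing in for the direct NVML/amdsmi query the real helper performed:

```python
import os
from functools import lru_cache
from typing import Optional

import torch


@lru_cache(maxsize=8)
def _device_count_cached(cuda_visible_devices: Optional[str]) -> int:
    # The argument is unused except as the cache key; the removed helper
    # queried NVML (or amdsmi on ROCm) here rather than going through torch,
    # but torch.cuda.device_count() keeps this sketch runnable.
    return torch.cuda.device_count()


def device_count_stateless() -> int:
    """Return the device count, re-reading CUDA_VISIBLE_DEVICES on every call."""
    return _device_count_cached(os.environ.get("CUDA_VISIBLE_DEVICES"))


if __name__ == "__main__":
    print(device_count_stateless())
```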