Unverified Commit 1a820e38 authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

Remove dependency of pynvml on ROCm (#2995)

parent 0ffcfdf4
...@@ -6,7 +6,6 @@ from contextlib import contextmanager ...@@ -6,7 +6,6 @@ from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union from typing import Callable, List, Optional, TypeVar, Union
import pynvml
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ...@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda from sglang.srt.utils import cuda_device_count_stateless, is_cuda
logger = logging.getLogger(__name__)
if is_cuda():
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
try: try:
if ops.use_vllm_custom_allreduce: if ops.use_vllm_custom_allreduce:
ops.meta_size() ops.meta_size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment