Commit f2e624d6 authored by xiabo
Browse files

已修改 vllm/utils/__init__.py

parent 8e95b5e2
......@@ -1173,12 +1173,11 @@ def find_nccl_library() -> str:
# Keep a handle to the original torch.cuda.set_stream so the patched
# wrapper below can delegate to it.
prev_set_stream = torch.cuda.set_stream
# Module-global record of the current stream; None until it is assigned.
# NOTE(review): the hunk header (12 lines -> 11) suggests this line is the
# removed side of the diff and the thread-local below replaces it — confirm.
_current_stream = None
# Thread-local holder: attributes set on it (here, `value`) are visible only
# to the thread that set them.
_current_stream_tls = threading.local()
# Wrapper around torch.cuda.set_stream: records the stream being set, then
# delegates to the original function saved in prev_set_stream.
# NOTE(review): this view looks like a diff with its +/- markers stripped, so
# the module-global bookkeeping and the thread-local bookkeeping below are
# probably the old and new sides of the same change rather than code that
# coexists — confirm against the real file.
def _patched_set_stream(stream: torch.cuda.Stream) -> None:
global _current_stream
# Record the stream in the module-level global.
_current_stream = stream
# Record the stream on the thread-local holder as `value`.
_current_stream_tls.value = stream
# Delegate to the saved original torch.cuda.set_stream.
prev_set_stream(stream)
......@@ -1197,8 +1196,7 @@ def current_stream() -> torch.cuda.Stream:
from C/C++ code.
"""
from vllm.platforms import current_platform
global _current_stream
if _current_stream is None:
if not hasattr(_current_stream_tls,"value") or _current_stream_tls.value is None:
# when this function is called before any stream is set,
# we return the default stream.
# On ROCm using the default 0 stream in combination with RCCL
......@@ -1208,8 +1206,8 @@ def current_stream() -> torch.cuda.Stream:
# fix computational precision issue in eager mode
# _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
# ) else torch.cuda.current_stream()
_current_stream = torch.cuda.current_stream()
return _current_stream
_current_stream_tls.value = torch.cuda.current_stream()
return _current_stream_tls.value
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment