Commit 9ff3592b authored by zhuwenwen's avatar zhuwenwen
Browse files

Revert "已修改 vllm/utils/__init__.py"

This reverts commit 20b6cf64.
parent dcec1430
...@@ -1173,11 +1173,12 @@ def find_nccl_library() -> str: ...@@ -1173,11 +1173,12 @@ def find_nccl_library() -> str:
prev_set_stream = torch.cuda.set_stream prev_set_stream = torch.cuda.set_stream
_current_stream_tls = threading.local() _current_stream = None
def _patched_set_stream(stream: torch.cuda.Stream) -> None: def _patched_set_stream(stream: torch.cuda.Stream) -> None:
_current_stream_tls.value = stream global _current_stream
_current_stream = stream
prev_set_stream(stream) prev_set_stream(stream)
...@@ -1196,8 +1197,8 @@ def current_stream() -> torch.cuda.Stream: ...@@ -1196,8 +1197,8 @@ def current_stream() -> torch.cuda.Stream:
from C/C++ code. from C/C++ code.
""" """
from vllm.platforms import current_platform from vllm.platforms import current_platform
if not hasattr(_current_stream_tls, global _current_stream
"value") or _current_stream_tls.value is None: if _current_stream is None:
# when this function is called before any stream is set, # when this function is called before any stream is set,
# we return the default stream. # we return the default stream.
# On ROCm using the default 0 stream in combination with RCCL # On ROCm using the default 0 stream in combination with RCCL
...@@ -1207,9 +1208,8 @@ def current_stream() -> torch.cuda.Stream: ...@@ -1207,9 +1208,8 @@ def current_stream() -> torch.cuda.Stream:
# fix computational precision issue in eager mode # fix computational precision issue in eager mode
# _current_stream = torch.cuda.Stream() if current_platform.is_rocm( # _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
# ) else torch.cuda.current_stream() # ) else torch.cuda.current_stream()
_current_stream_tls.value = torch.cuda.Stream( _current_stream = torch.cuda.current_stream()
) if current_platform.is_rocm() else torch.cuda.current_stream() return _current_stream
return _current_stream_tls.value
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment