Commit a15d668b authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-pyncclstream' into 'v0.9.2-dev'

已修改        vllm/utils/__init__.py

See merge request dcutoolkit/deeplearing/vllm!171
parents 88dbf92c 20b6cf64
...@@ -1173,12 +1173,11 @@ def find_nccl_library() -> str: ...@@ -1173,12 +1173,11 @@ def find_nccl_library() -> str:
prev_set_stream = torch.cuda.set_stream prev_set_stream = torch.cuda.set_stream
_current_stream = None _current_stream_tls = threading.local()
def _patched_set_stream(stream: torch.cuda.Stream) -> None: def _patched_set_stream(stream: torch.cuda.Stream) -> None:
global _current_stream _current_stream_tls.value = stream
_current_stream = stream
prev_set_stream(stream) prev_set_stream(stream)
...@@ -1197,8 +1196,8 @@ def current_stream() -> torch.cuda.Stream: ...@@ -1197,8 +1196,8 @@ def current_stream() -> torch.cuda.Stream:
from C/C++ code. from C/C++ code.
""" """
from vllm.platforms import current_platform from vllm.platforms import current_platform
global _current_stream if not hasattr(_current_stream_tls,
if _current_stream is None: "value") or _current_stream_tls.value is None:
# when this function is called before any stream is set, # when this function is called before any stream is set,
# we return the default stream. # we return the default stream.
# On ROCm using the default 0 stream in combination with RCCL # On ROCm using the default 0 stream in combination with RCCL
...@@ -1208,8 +1207,9 @@ def current_stream() -> torch.cuda.Stream: ...@@ -1208,8 +1207,9 @@ def current_stream() -> torch.cuda.Stream:
# fix computational precision issue in eager mode # fix computational precision issue in eager mode
# _current_stream = torch.cuda.Stream() if current_platform.is_rocm( # _current_stream = torch.cuda.Stream() if current_platform.is_rocm(
# ) else torch.cuda.current_stream() # ) else torch.cuda.current_stream()
_current_stream = torch.cuda.current_stream() _current_stream_tls.value = torch.cuda.Stream(
return _current_stream ) if current_platform.is_rocm() else torch.cuda.current_stream()
return _current_stream_tls.value
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment