Commit 88016c37 (unverified), authored by Li, Jiang; committed by GitHub
Browse files

[Bugfix] Fix pooling models on CPU backend (#23392)


Signed-off-by: jiang1.li <jiang1.li@intel.com>
parent 99872085
...@@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None: ...@@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None:
torch.cuda.set_stream = _patched_set_stream torch.cuda.set_stream = _patched_set_stream
class _StreamPlaceholder:
    """Stand-in stream object exposing a no-op ``synchronize()``.

    ``current_stream()`` stores an instance of this class as the
    thread-local stream when the current platform is CPU, so callers can
    invoke ``.synchronize()`` on the result without special-casing that
    platform. Only ``synchronize`` is provided; no other stream
    attributes are emulated.
    """

    def __init__(self) -> None:
        # No state is needed: the placeholder only has to answer
        # ``synchronize()``.
        pass

    def synchronize(self) -> None:
        """Do nothing and return ``None``."""
        return None
def current_stream() -> torch.cuda.Stream: def current_stream() -> torch.cuda.Stream:
""" """
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
...@@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream: ...@@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream:
# On ROCm using the default 0 stream in combination with RCCL # On ROCm using the default 0 stream in combination with RCCL
# is hurting performance. Therefore creating a dedicated stream # is hurting performance. Therefore creating a dedicated stream
# per process # per process
_current_stream_tls.value = torch.cuda.Stream( if current_platform.is_rocm():
) if current_platform.is_rocm() else torch.cuda.current_stream() _current_stream_tls.value = torch.cuda.Stream()
elif current_platform.is_cpu():
_current_stream_tls.value = _StreamPlaceholder()
else:
current_stream = current_platform.current_stream
if current_stream is not None:
_current_stream_tls.value = current_stream()
else:
raise ValueError(
"Fail to set current stream, current platform "
"may not support current_stream with torch API")
return _current_stream_tls.value return _current_stream_tls.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment