Unverified Commit 692db29c authored by Nikita Shapovalov's avatar Nikita Shapovalov Committed by GitHub
Browse files

[Bugfix] Fix Ray compiled-DAG SHM channel stalls by detaching zero-copy...


[Bugfix] Fix Ray compiled-DAG SHM channel stalls by detaching zero-copy `np.ndarray` logprobs buffers (#35736)
Signed-off-by: default avatarNikita Shapovalov <nikita@poolside.ai>
parent 82531edb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.v1.executor.ray_utils import detach_zero_copy_from_model_runner_output
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, ModelRunnerOutput
def _make_readonly(arr: np.ndarray) -> np.ndarray:
arr.setflags(write=False)
return arr
def test_detach_zero_copy_from_model_runner_output_copies_only_numpy_views():
cu_num_generated_tokens = [0, 2]
prompt_logprobs = LogprobsTensors.empty_cpu(1, 2)
output = ModelRunnerOutput(
req_ids=["req-0"],
req_id_to_index={"req-0": 0},
logprobs=LogprobsLists(
logprob_token_ids=_make_readonly(
np.array([[1, 2], [3, 4]], dtype=np.int32)
),
logprobs=_make_readonly(
np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
),
sampled_token_ranks=_make_readonly(np.array([1, 2], dtype=np.int32)),
cu_num_generated_tokens=cu_num_generated_tokens,
),
prompt_logprobs_dict={"req-0": prompt_logprobs},
)
original_logprobs = output.logprobs
assert original_logprobs is not None
detach_zero_copy_from_model_runner_output(output)
detached_logprobs = output.logprobs
assert detached_logprobs is not None
assert detached_logprobs is not original_logprobs
assert (
detached_logprobs.logprob_token_ids is not original_logprobs.logprob_token_ids
)
assert detached_logprobs.logprobs is not original_logprobs.logprobs
assert (
detached_logprobs.sampled_token_ranks
is not original_logprobs.sampled_token_ranks
)
assert detached_logprobs.logprob_token_ids.flags.writeable
assert detached_logprobs.logprobs.flags.writeable
assert detached_logprobs.sampled_token_ranks.flags.writeable
assert detached_logprobs.cu_num_generated_tokens is cu_num_generated_tokens
assert output.prompt_logprobs_dict["req-0"] is prompt_logprobs
...@@ -26,6 +26,7 @@ from vllm.v1.executor.ray_utils import ( ...@@ -26,6 +26,7 @@ from vllm.v1.executor.ray_utils import (
WORKER_SPECIFIC_ENV_VARS, WORKER_SPECIFIC_ENV_VARS,
FutureWrapper, FutureWrapper,
RayWorkerWrapper, RayWorkerWrapper,
detach_zero_copy_from_model_runner_output,
initialize_ray_cluster, initialize_ray_cluster,
ray, ray,
) )
...@@ -463,7 +464,9 @@ class RayDistributedExecutor(Executor): ...@@ -463,7 +464,9 @@ class RayDistributedExecutor(Executor):
# Get output only from a single worker (output_rank) # Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available. # When PP is not used, we block here until the result is available.
if not non_block: if not non_block:
return refs[0].get() output = refs[0].get()
detach_zero_copy_from_model_runner_output(output)
return output
# When PP is used, we return a FutureWrapper immediately so that # When PP is used, we return a FutureWrapper immediately so that
# the scheduler can yield to the next batch. # the scheduler can yield to the next batch.
...@@ -473,7 +476,10 @@ class RayDistributedExecutor(Executor): ...@@ -473,7 +476,10 @@ class RayDistributedExecutor(Executor):
assert self.kv_output_aggregator is not None assert self.kv_output_aggregator is not None
if not non_block: if not non_block:
# Block and get results from all workers # Block and get results from all workers
return self.kv_output_aggregator.aggregate(ray.get(refs)) outputs = ray.get(refs)
for output in outputs:
detach_zero_copy_from_model_runner_output(output)
return self.kv_output_aggregator.aggregate(outputs)
# Return a future that will aggregate outputs from all workers # Return a future that will aggregate outputs from all workers
return FutureWrapper(refs, self.kv_output_aggregator) return FutureWrapper(refs, self.kv_output_aggregator)
......
...@@ -7,6 +7,8 @@ from collections import defaultdict ...@@ -7,6 +7,8 @@ from collections import defaultdict
from concurrent.futures import Future from concurrent.futures import Future
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
import numpy as np
import vllm.platforms import vllm.platforms
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
...@@ -190,6 +192,48 @@ except ImportError as e: ...@@ -190,6 +192,48 @@ except ImportError as e:
RayWorkerWrapper = None # type: ignore RayWorkerWrapper = None # type: ignore
def detach_zero_copy_from_model_runner_output(output: "ModelRunnerOutput") -> None:
"""Detach Ray SHM-channel zero-copy buffers from a ModelRunnerOutput in-place.
Ray compiled DAG SHM channels may return zero-copy objects (e.g. `np.ndarray`)
backed by Ray's shared-memory object store. Ray's channel docs explicitly
warn that subsequent reads may block if such an object is still in scope.
vLLM can return numpy-backed logprobs in `ModelRunnerOutput.logprobs`. If
those arrays are backed by Ray SHM (commonly read-only), retaining them in
scope across scheduler iterations can stall the channel and eventually hit
`RAY_CGRAPH_get_timeout`.
Copy read-only numpy arrays so the returned output no longer retains
references to Ray's shared-memory buffers.
We intentionally do not touch `prompt_logprobs_dict`: those entries are
`LogprobsTensors` backed by PyTorch-owned CPU tensors (`to_cpu_nonblocking`
or `empty_cpu`), not NumPy views decoded from Ray channels.
"""
if output.logprobs is None:
return
token_ids, logprobs, ranks, cu_num_generated_tokens = output.logprobs
def _copy_if_readonly(arr):
if isinstance(arr, np.ndarray) and not arr.flags.writeable:
return arr.copy()
return arr
# `cu_num_generated_tokens` is already a plain Python list (or None), so it
# never aliases Ray SHM buffers and can be reused as-is.
token_ids_c = _copy_if_readonly(token_ids)
logprobs_c = _copy_if_readonly(logprobs)
ranks_c = _copy_if_readonly(ranks)
if token_ids_c is token_ids and logprobs_c is logprobs and ranks_c is ranks:
return
output.logprobs = type(output.logprobs)(
token_ids_c, logprobs_c, ranks_c, cu_num_generated_tokens
)
class FutureWrapper(Future): class FutureWrapper(Future):
"""A wrapper around Ray output reference to meet the interface """A wrapper around Ray output reference to meet the interface
of .execute_model(): The top level (core busy loop) expects .result() api of .execute_model(): The top level (core busy loop) expects .result() api
...@@ -207,8 +251,11 @@ class FutureWrapper(Future): ...@@ -207,8 +251,11 @@ class FutureWrapper(Future):
def result(self, timeout=None): def result(self, timeout=None):
outputs = ray.get(self.ref_or_refs, timeout=timeout) outputs = ray.get(self.ref_or_refs, timeout=timeout)
if self.aggregator is None: if self.aggregator is None:
detach_zero_copy_from_model_runner_output(outputs)
return outputs return outputs
for output in outputs:
detach_zero_copy_from_model_runner_output(output)
return self.aggregator.aggregate(outputs, output_rank=0) return self.aggregator.aggregate(outputs, output_rank=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment