"include/vscode:/vscode.git/clone" did not exist on "10732847e73496e59f398d894c50dd9a920f1bd4"
Unverified Commit 70b81c4f authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[bugfix][async scheduling] fix extra cuda context in device 0 with EP/DP (#37449)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 7476d148
......@@ -597,17 +597,6 @@ class WorkerProc:
wrapper.init_worker(all_kwargs)
self.worker = wrapper
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
self.async_output_queue: queue.Queue = queue.Queue()
self.async_output_copy_thread = Thread(
target=self.async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy",
)
self.async_output_copy_thread.start()
self.setup_proc_title_and_log_prefix(
enable_ep=vllm_config.parallel_config.enable_expert_parallel
)
......@@ -622,6 +611,17 @@ class WorkerProc:
)
self.worker.load_model()
scheduler_config = vllm_config.scheduler_config
self.use_async_scheduling = scheduler_config.async_scheduling
if self.use_async_scheduling:
self.async_output_queue: queue.Queue = queue.Queue()
self.async_output_copy_thread = Thread(
target=self.async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy",
)
self.async_output_copy_thread.start()
# Set block size based on the attention backends
current_platform.update_block_size_for_backend(vllm_config)
......@@ -911,6 +911,18 @@ class WorkerProc:
def async_output_busy_loop(self):
"""Entrypoint for the thread which handles outputs asynchronously."""
# set device to the worker device for the thread.
# a thread will not inherit the context of the main thread.
# when calling any cuda runtime functions, it will implicitly
# create a new cuda context on device 0, consuming extra memory.
# here we set the device to the worker device for the thread,
# enforcing the context to be the same as the main thread.
from vllm.platforms import current_platform
if hasattr(self.worker, "device"):
current_platform.set_device(self.worker.device)
while True:
output = self.async_output_queue.get()
self.enqueue_output(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment