Unverified Commit 198d6a28 authored by Rui Qiao's avatar Rui Qiao Committed by GitHub
Browse files

[Core] Shut down aDAG workers with clean async llm engine exit (#7224)


Signed-off-by: default avatarRui Qiao <ruisearch42@gmail.com>
parent 774cd1d3
...@@ -34,9 +34,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, ...@@ -34,9 +34,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend") "multiprocessing distributed backend")
USE_RAY_ADAG_NCCL = 0
USE_RAY_ADAG = 0
pp_args = [ pp_args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -70,14 +67,13 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, ...@@ -70,14 +67,13 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
pp_args.append("--enforce-eager") pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager") tp_args.append("--enforce-eager")
pp_env = None pp_env = None
if USE_RAY_ADAG: if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2
assert DIST_BACKEND == "ray", ( and CHUNKED_PREFILL):
"Ray ADAG is only supported with Ray distributed backend") # Test Ray ADAG for a subset of the tests
pp_env = { pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1", "VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1", "VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
str(int(USE_RAY_ADAG_NCCL)),
} }
compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
......
...@@ -661,6 +661,20 @@ class AsyncLLMEngine: ...@@ -661,6 +661,20 @@ class AsyncLLMEngine:
partial(_log_task_completion, error_callback=self._error_callback)) partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded) self.background_loop = asyncio.shield(self._background_loop_unshielded)
def shutdown_background_loop(self) -> None:
"""
Shut down the background loop.
This method needs to be called during cleanup to remove
references to `self` and properly GC the resources held
by the async LLM engine (e.g., the executors as well as
their resources).
"""
if self._background_loop_unshielded is not None:
self._background_loop_unshielded.cancel()
self._background_loop_unshielded = None
self.background_loop = None
def _init_engine(self, *args, def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray: if not self.engine_use_ray:
......
...@@ -245,9 +245,18 @@ class LLMEngine: ...@@ -245,9 +245,18 @@ class LLMEngine:
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else: else:
self.tokenizer = None self.tokenizer = None
self.detokenizer = None self.detokenizer = None
tokenizer_group = None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
...@@ -356,10 +365,10 @@ class LLMEngine: ...@@ -356,10 +365,10 @@ class LLMEngine:
self.detokenizer, self.detokenizer,
self.scheduler, self.scheduler,
self.seq_counter, self.seq_counter,
self.get_tokenizer_for_seq, get_tokenizer_for_seq,
stop_checker=StopChecker( stop_checker=StopChecker(
self.scheduler_config.max_model_len, self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq, get_tokenizer_for_seq,
), ),
)) ))
...@@ -491,10 +500,6 @@ class LLMEngine: ...@@ -491,10 +500,6 @@ class LLMEngine:
) -> AnyTokenizer: ) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request) return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup: def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs( return init_tokenizer_from_configs(
model_config=self.model_config, model_config=self.model_config,
......
...@@ -36,6 +36,7 @@ class AsyncEngineRPCServer: ...@@ -36,6 +36,7 @@ class AsyncEngineRPCServer:
"""Cleanup all resources.""" """Cleanup all resources."""
self.socket.close() self.socket.close()
self.context.destroy() self.context.destroy()
self.engine.shutdown_background_loop()
async def get_model_config(self, identity): async def get_model_config(self, identity):
"""Send the ModelConfig""" """Send the ModelConfig"""
......
...@@ -60,6 +60,14 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -60,6 +60,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
def shutdown(self) -> None:
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self, def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]: ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling # If nsight profiling is enabled, we need to set the profiling
...@@ -117,7 +125,6 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -117,7 +125,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
logger.info("driver_ip: %s", driver_ip)
worker_wrapper_kwargs = self._get_worker_wrapper_args() worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
...@@ -446,11 +453,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -446,11 +453,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self): def __del__(self):
if self.forward_dag is not None: self.shutdown()
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...@@ -523,8 +526,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -523,8 +526,4 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
def __del__(self): def __del__(self):
if self.forward_dag is not None: self.shutdown()
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment