Commit 708d897e authored by zhuwenwen
Browse files

Fix multiprocessing shutdown errors

parent 3d087876
...@@ -232,76 +232,91 @@ class LLMEngine: ...@@ -232,76 +232,91 @@ class LLMEngine:
load_config=load_config, load_config=load_config,
) )
if not self.model_config.embedding_mode: init_success = False
self._initialize_kv_caches() try:
if not self.model_config.embedding_mode:
# If usage stat is enabled, collect relevant info. self._initialize_kv_caches()
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import ( # If usage stat is enabled, collect relevant info.
get_architecture_class_name) if is_usage_stats_enabled():
usage_message.report_usage( from vllm.model_executor.model_loader import (
get_architecture_class_name(model_config), get_architecture_class_name)
usage_context, usage_message.report_usage(
extra_kvs={ get_architecture_class_name(model_config),
# Common configuration usage_context,
"dtype": extra_kvs={
str(model_config.dtype), # Common configuration
"tensor_parallel_size": "dtype":
parallel_config.tensor_parallel_size, str(model_config.dtype),
"block_size": "tensor_parallel_size":
cache_config.block_size, parallel_config.tensor_parallel_size,
"gpu_memory_utilization": "block_size":
cache_config.gpu_memory_utilization, cache_config.block_size,
"gpu_memory_utilization":
# Quantization cache_config.gpu_memory_utilization,
"quantization":
model_config.quantization, # Quantization
"kv_cache_dtype": "quantization":
cache_config.cache_dtype, model_config.quantization,
"kv_cache_dtype":
# Feature flags cache_config.cache_dtype,
"enable_lora":
bool(lora_config), # Feature flags
"enable_prefix_caching": "enable_lora":
cache_config.enable_prefix_caching, bool(lora_config),
"enforce_eager": "enable_prefix_caching":
model_config.enforce_eager, cache_config.enable_prefix_caching,
"disable_custom_all_reduce": "enforce_eager":
parallel_config.disable_custom_all_reduce, model_config.enforce_eager,
}) "disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
if self.tokenizer: })
# Ping the tokenizer to ensure liveness if it runs in a
# different process. if self.tokenizer:
self.tokenizer.ping() # Ping the tokenizer to ensure liveness if it runs in a
# different process.
# Create the scheduler. self.tokenizer.ping()
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor. # Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
# Metric Logging. self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
if self.log_stats:
self.stat_logger = StatLogger( # Metric Logging.
local_interval=_LOCAL_LOGGING_INTERVAL_SEC, if self.log_stats:
labels=dict(model_name=model_config.served_model_name), self.stat_logger = StatLogger(
max_model_len=self.model_config.max_model_len) local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
self.stat_logger.info("cache_config", self.cache_config) labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
# Create sequence output processor, e.g. for beam search or self.stat_logger.info("cache_config", self.cache_config)
# speculative decoding.
self.output_processor = ( tokenizer_group = self.get_tokenizer_group()
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config, def get_tokenizer_for_seq(self,
self.detokenizer, sequence: Sequence) -> "PreTrainedTokenizer":
self.scheduler, return tokenizer_group.get_lora_tokenizer(
self.seq_counter, sequence.lora_request)
self.get_tokenizer_for_seq,
stop_checker=StopChecker( # Create sequence output processor, e.g. for beam search or
self.scheduler_config.max_model_len, # speculative decoding.
self.get_tokenizer_for_seq, self.output_processor = (
), SequenceGroupOutputProcessor.create_output_processor(
)) self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
...@@ -390,10 +405,10 @@ class LLMEngine: ...@@ -390,10 +405,10 @@ class LLMEngine:
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None) return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self, # def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": # sequence: Sequence) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer( # return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request) # sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict( init_kwargs = dict(
...@@ -782,7 +797,8 @@ class LLMEngine: ...@@ -782,7 +797,8 @@ class LLMEngine:
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, output) self.do_log_stats(scheduler_outputs, output)
if not request_outputs: # if not request_outputs:
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there are # Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in # more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks # torch.distributed ops which may otherwise timeout, and unblocks
......
...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread): ...@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)""" """Handle results from all workers (in background thread)"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue() self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread): ...@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'], def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler): result_handler: ResultHandler):
super().__init__(daemon=True) super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers self.workers = workers
self.result_handler = result_handler self.result_handler = result_handler
self._close = False self._close = False
...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread): ...@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True self._close = True
# Kill / cleanup all workers # Kill / cleanup all workers
for worker in self.workers: # for worker in self.workers:
process = worker.process # process = worker.process
if process.sentinel in dead_sentinels: # if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S) # process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0: # if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s", # logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode) # process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers # Cleanup any remaining workers
logger.info("Killing local vLLM worker processes") # logger.info("Killing local vLLM worker processes")
for worker in self.workers: for worker in self.workers:
worker.kill_worker() worker.kill_worker()
# Must be done after worker task queues are all closed # Must be done after worker task queues are all closed
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment