Commit 708d897e authored by zhuwenwen's avatar zhuwenwen
Browse files

Fix multiprocessing shutdown errors

parent 3d087876
......@@ -232,76 +232,91 @@ class LLMEngine:
load_config=load_config,
)
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
self.get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
self.get_tokenizer_for_seq,
),
))
init_success = False
try:
if not self.model_config.embedding_mode:
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
if is_usage_stats_enabled():
from vllm.model_executor.model_loader import (
get_architecture_class_name)
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
tokenizer_group = self.get_tokenizer_group()
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return tokenizer_group.get_lora_tokenizer(
sequence.lora_request)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
))
init_success = True
finally:
if not init_success:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self.model_executor.shutdown()
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
......@@ -390,10 +405,10 @@ class LLMEngine:
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
# def get_tokenizer_for_seq(self,
# sequence: Sequence) -> "PreTrainedTokenizer":
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
......@@ -782,7 +797,8 @@ class LLMEngine:
# Log stats.
self.do_log_stats(scheduler_outputs, output)
if not request_outputs:
# if not request_outputs:
if not self.has_unfinished_requests():
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
......
......@@ -76,7 +76,8 @@ class ResultHandler(threading.Thread):
"""Handle results from all workers (in background thread)"""
def __init__(self) -> None:
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.result_queue = mp.Queue()
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
......@@ -100,7 +101,8 @@ class WorkerMonitor(threading.Thread):
def __init__(self, workers: List['ProcessWorkerWrapper'],
result_handler: ResultHandler):
super().__init__(daemon=True)
super().__init__(daemon=False)
# super().__init__(daemon=True)
self.workers = workers
self.result_handler = result_handler
self._close = False
......@@ -112,15 +114,31 @@ class WorkerMonitor(threading.Thread):
self._close = True
# Kill / cleanup all workers
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# for worker in self.workers:
# process = worker.process
# if process.sentinel in dead_sentinels:
# process.join(JOIN_TIMEOUT_S)
# if process.exitcode is not None and process.exitcode != 0:
# logger.error("Worker %s pid %s died, exit code: %s",
# process.name, process.pid, process.exitcode)
if not sys.is_finalizing():
# Kill / cleanup all workers
died_count = 0
for worker in self.workers:
process = worker.process
if process.sentinel in dead_sentinels:
process.join(JOIN_TIMEOUT_S)
if process.exitcode is not None and process.exitcode != 0:
died_count += 1
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid,
process.exitcode)
if died_count < len(self.workers):
logger.info(
"Killing remaining local vLLM worker processes")
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
# logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment