Unverified Commit b30372cb authored by Jialin Ouyang's avatar Jialin Ouyang Committed by GitHub
Browse files

[Perf] Move gc.freeze logic from EngineCoreProc to EngineCore for better coverage (#27896)


Signed-off-by: default avatarJialin Ouyang <Jialin.Ouyang@gmail.com>
parent d17ecc6b
...@@ -19,7 +19,6 @@ On the client side, run: ...@@ -19,7 +19,6 @@ On the client side, run:
import argparse import argparse
import asyncio import asyncio
import contextlib import contextlib
import gc
import importlib.util import importlib.util
import json import json
import os import os
...@@ -49,6 +48,7 @@ from vllm.benchmarks.lib.endpoint_request_func import ( ...@@ -49,6 +48,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.gc_utils import freeze_gc_heap
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
...@@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
percentile_metrics: str = args.percentile_metrics or default_percentile_metrics percentile_metrics: str = args.percentile_metrics or default_percentile_metrics
# Avoid GC processing "static" data - reduce pause times. # Avoid GC processing "static" data - reduce pause times.
gc.collect() freeze_gc_heap()
gc.freeze()
benchmark_result = await benchmark( benchmark_result = await benchmark(
task_type=task_type, task_type=task_type,
......
...@@ -1483,6 +1483,9 @@ def destroy_distributed_environment(): ...@@ -1483,6 +1483,9 @@ def destroy_distributed_environment():
def cleanup_dist_env_and_memory(shutdown_ray: bool = False): def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
# Ensure no objects are left frozen before cleanup
gc.unfreeze()
destroy_model_parallel() destroy_model_parallel()
destroy_distributed_environment() destroy_distributed_environment()
if shutdown_ray: if shutdown_ray:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import gc
import hashlib import hashlib
import importlib import importlib
import inspect import inspect
...@@ -118,6 +116,7 @@ from vllm.reasoning import ReasoningParserManager ...@@ -118,6 +116,7 @@ from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
...@@ -153,8 +152,7 @@ async def lifespan(app: FastAPI): ...@@ -153,8 +152,7 @@ async def lifespan(app: FastAPI):
# Mark the startup heap as static so that it's ignored by GC. # Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections. # Reduces pause times of oldest generation collections.
gc.collect() freeze_gc_heap()
gc.freeze()
try: try:
yield yield
finally: finally:
......
...@@ -89,6 +89,21 @@ class GCDebugger: ...@@ -89,6 +89,21 @@ class GCDebugger:
) )
def freeze_gc_heap() -> None:
    """
    Collect every GC generation in turn, then freeze the surviving heap.

    Intended to be called once server init / warmup has finished, so that
    long-lived "static" objects stop adding to garbage-collection overhead
    while serving requests.
    """
    # Run the collections from youngest to oldest so that surviving static
    # objects are promoted into the oldest generation ahead of the freeze.
    for generation in (0, 1, 2):
        gc.collect(generation)
    # Freeze everything the collector currently tracks.
    gc.freeze()
def maybe_attach_gc_debug_callback() -> None: def maybe_attach_gc_debug_callback() -> None:
""" """
Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. Attached a callback for GC debug when VLLM_GC_DEBUG is enabled.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os import os
import queue import queue
import signal import signal
...@@ -27,7 +26,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -27,7 +26,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback from vllm.utils.gc_utils import (
freeze_gc_heap,
maybe_attach_gc_debug_callback,
)
from vllm.utils.hashing import get_hash_fn_by_name from vllm.utils.hashing import get_hash_fn_by_name
from vllm.utils.network_utils import make_zmq_socket from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import decorate_logs, set_process_title from vllm.utils.system_utils import decorate_logs, set_process_title
...@@ -197,6 +199,10 @@ class EngineCore: ...@@ -197,6 +199,10 @@ class EngineCore:
self.step if self.batch_queue is None else self.step_with_batch_queue self.step if self.batch_queue is None else self.step_with_batch_queue
) )
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
def _initialize_kv_caches( def _initialize_kv_caches(
self, vllm_config: VllmConfig self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]: ) -> tuple[int, int, KVCacheConfig]:
...@@ -651,11 +657,6 @@ class EngineCoreProc(EngineCore): ...@@ -651,11 +657,6 @@ class EngineCoreProc(EngineCore):
assert addresses.coordinator_input is not None assert addresses.coordinator_input is not None
logger.info("Waiting for READY message from DP Coordinator...") logger.info("Waiting for READY message from DP Coordinator...")
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
gc.freeze()
# If enable, attach GC debugger after static variable freeze. # If enable, attach GC debugger after static variable freeze.
maybe_attach_gc_debug_callback() maybe_attach_gc_debug_callback()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment