Unverified Commit 4acf6902 authored by Brayden Zhong's avatar Brayden Zhong Committed by GitHub
Browse files

[Optimization][Perf] Disable the GC during CUDA graph capture to speed up by up to 3x (#8577)

parent aee0ef52
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from __future__ import annotations from __future__ import annotations
import bisect import bisect
import gc
import inspect import inspect
import logging import logging
import os import os
...@@ -75,6 +76,24 @@ def model_capture_mode(): ...@@ -75,6 +76,24 @@ def model_capture_mode():
is_capture_mode = False is_capture_mode = False
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """Minimize garbage-collection overhead while CUDA graphs are captured.

    Performs one full collection up front. Unless ``enable_cudagraph_gc``
    is True, every object that survives that collection is then frozen so
    the collector ignores it for the duration of the capture; the freeze
    is always reverted on exit, even if the capture raises.
    """
    gc.collect()
    if enable_cudagraph_gc:
        # GC explicitly kept enabled during capture: nothing to freeze.
        yield
        return
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values(): for sub in model._modules.values():
if isinstance(sub, CustomOp): if isinstance(sub, CustomOp):
...@@ -423,7 +442,12 @@ class CudaGraphRunner: ...@@ -423,7 +442,12 @@ class CudaGraphRunner:
record_shapes=True, record_shapes=True,
) )
with graph_capture() as graph_capture_context: # Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
with freeze_gc(
self.model_runner.server_args.enable_cudagraph_gc
), graph_capture() as graph_capture_context:
with profile_context as prof: with profile_context as prof:
self.stream = graph_capture_context.stream self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory( avail_mem = get_available_gpu_memory(
......
...@@ -215,6 +215,7 @@ class ServerArgs: ...@@ -215,6 +215,7 @@ class ServerArgs:
disable_cuda_graph: bool = False disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False disable_cuda_graph_padding: bool = False
enable_profile_cuda_graph: bool = False enable_profile_cuda_graph: bool = False
enable_cudagraph_gc: bool = False
enable_nccl_nvls: bool = False enable_nccl_nvls: bool = False
enable_tokenizer_batch_encode: bool = False enable_tokenizer_batch_encode: bool = False
disable_outlines_disk_cache: bool = False disable_outlines_disk_cache: bool = False
...@@ -1545,6 +1546,11 @@ class ServerArgs: ...@@ -1545,6 +1546,11 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable profiling of cuda graph capture.", help="Enable profiling of cuda graph capture.",
) )
parser.add_argument(
"--enable-cudagraph-gc",
action="store_true",
help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
)
parser.add_argument( parser.add_argument(
"--enable-nccl-nvls", "--enable-nccl-nvls",
action="store_true", action="store_true",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment