"vscode:/vscode.git/clone" did not exist on "80679f108ffd94c165ea11adbc3afcc43f24a06e"
Unverified Commit 4778b426 authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

Reduce the Cuda Graph memory footprint when running with DBO (#25779)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent c70ac4b8
......@@ -3477,8 +3477,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
# We currently only capture ubatched graphs when its a FULL
# cudagraph and for uniform decode batches.
capture_ubatched_graph = self.parallel_config.enable_dbo \
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
# version of the graph
allow_microbatching = self.parallel_config.enable_dbo \
and cudagraph_runtime_mode == CUDAGraphMode.FULL \
and uniform_decode \
and check_ubatch_thresholds(
......@@ -3487,17 +3489,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
uniform_decode=uniform_decode,
)
# Currently we capture both microbatched and non-microbatched
# graphs when capture_ubatched_graph is True, this is because
# occasionally we will be forced out of microbatching due to other
# DP ranks not microbatching (usually caused by an empty second
# microbatch; once we resolve this, we can remove the
# non-microbatched graph capture).
allow_microbatching_options = [True, False] if \
capture_ubatched_graph else [False]
for allow_microbatching in allow_microbatching_options:
for _ in range(
self.compilation_config.cudagraph_num_of_warmups):
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
......
......@@ -330,6 +330,18 @@ class UBatchWrapper:
# If there's no ubatching, just run the runnable object
if ubatch_slices is None:
# This is to account for the case where ubatching was aborted.
# When we capture full graphs we only capture one graph per shape,
# meaning that if we have a ubatched cudagraph for the current
# num_tokens, we don't have a non-ubatched one. Without this
# check, the cudagraph wrapper will try to capture a cudagraph
# for this shape during a normal run.
if cudagraph_runtime_mode is CUDAGraphMode.FULL:
assert batch_descriptor is not None
if batch_descriptor.num_tokens in self.cudagraphs:
cudagraph_runtime_mode = CUDAGraphMode.NONE
if cudagraph_runtime_mode in (CUDAGraphMode.NONE,
CUDAGraphMode.PIECEWISE):
return self.runnable(*args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment