Unverified Commit dc464a3d authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] AssertionError: Do not capture num_reqs > max_num_reqs for uniform batch (#25505)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 1210e4d9
...@@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False, force_attention: bool = False,
uniform_decode: bool = False, uniform_decode: bool = False,
allow_microbatching: bool = True, allow_microbatching: bool = True,
...@@ -2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args: Args:
num_tokens: Number of tokens to run the dummy forward pass. num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior. cudagraph_runtime_mode: used to control the behavior.
- if not set will determine the cudagraph mode based on using
the self.cudagraph_dispatcher.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
...@@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests. (1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run remove_lora: If False, dummy LoRAs are not destroyed after the run
""" """
assert cudagraph_runtime_mode in { assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
...@@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif uniform_decode: elif uniform_decode:
assert not create_mixed_batch assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len) num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
f"{max_num_reqs} for uniform batch. Num tokens: " \
f"{num_tokens}, max_query_len: {max_query_len}"
num_scheduled_tokens_list = [max_query_len] * num_reqs num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0: if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len num_scheduled_tokens_list[-1] = num_tokens % max_query_len
...@@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False) num_tokens, None, False)
if cudagraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None
else:
# filter out the valid batch descriptor # filter out the valid batch descriptor
_cg_mode, batch_descriptor = \ _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
self.cudagraph_dispatcher.dispatch(
BatchDescriptor(num_tokens=num_tokens, BatchDescriptor(num_tokens=num_tokens,
uniform_decode=uniform_decode)) uniform_decode=uniform_decode))
# sanity check if cudagraph_runtime_mode is not None:
assert cudagraph_runtime_mode == _cg_mode, ( # we allow forcing NONE when the dispatcher disagrees to support
# warm ups for cudagraph capture
assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
cudagraph_runtime_mode == _cg_mode, (
f"Cudagraph runtime mode mismatch at dummy_run. " f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
else:
cudagraph_runtime_mode = _cg_mode
if ubatch_slices is not None: if ubatch_slices is not None:
num_tokens = num_tokens // 2 num_tokens = num_tokens // 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment