Commit 91c7bda5 authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]修复cudagraph和eager分段模式开启mla后报错问题

parent ef8f16f4
......@@ -888,6 +888,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_encoder_seq_len=max_encoder_seq_len)
batch_size = len(input_tokens)
if batch_size + cuda_graph_pad_size >= self.runner.enforce_eager_bs_threshould:
cuda_graph_pad_size = -1
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
......@@ -1717,7 +1721,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine = model_input.virtual_engine
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and decode_meta.use_cuda_graph and \
model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
model_input.input_tokens.shape[0] < self.enforce_eager_bs_threshould:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment