Commit 91c7bda5 authored by zhuwenwen
Browse files

[fix]修复cudagraph和eager分段模式开启mla后报错问题

parent ef8f16f4
@@ -888,6 +888,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     max_encoder_seq_len=max_encoder_seq_len)
 batch_size = len(input_tokens)
+if batch_size + cuda_graph_pad_size >= self.runner.enforce_eager_bs_threshould:
+    cuda_graph_pad_size = -1
 if cuda_graph_pad_size != -1:
     # If cuda graph can be used, pad tensors accordingly.
     # See `capture_model` API for more details.
@@ -1717,7 +1721,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 virtual_engine = model_input.virtual_engine
 previous_hidden_states = kwargs.get("previous_hidden_states")
 if prefill_meta is None and decode_meta.use_cuda_graph and \
-        model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
+        model_input.input_tokens.shape[0] < self.enforce_eager_bs_threshould:
     assert model_input.input_tokens is not None
     graph_batch_size = model_input.input_tokens.shape[0]
     model_executable = self.graph_runners[virtual_engine][
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment