Commit e0c80c12 authored by 王敏
Browse files

[fix]修复cudagraph和eager分段模式下,开启mla后报错问题

parent f98420b4
......@@ -892,6 +892,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
max_encoder_seq_len=max_encoder_seq_len)
batch_size = len(input_tokens)
+if batch_size + cuda_graph_pad_size >= self.runner.enforce_eager_bs_threshould:
+cuda_graph_pad_size = -1
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
......@@ -1709,7 +1713,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph and \
-model_input.input_tokens.shape[0] <= self.enforce_eager_bs_threshould:
+model_input.input_tokens.shape[0] < self.enforce_eager_bs_threshould:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment