Commit 5a120438 authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]修复开启mtp并且显存不足时发生的超出维度限制问题

parent e3d4e9a9
...@@ -207,6 +207,7 @@ from vllm.attention.backends.utils import get_mla_dims ...@@ -207,6 +207,7 @@ from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
...@@ -403,6 +404,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -403,6 +404,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# support for cudagraph spec docoding # support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None self.spec_decode_seq_lens = None
self.decode_token_num_threshold = 1
vllm_config = get_current_vllm_config()
speculative_config = vllm_config.speculative_config
if speculative_config and speculative_config.num_speculative_tokens > 1:
self.use_spec_decode = True
self.decode_token_num_threshold = 1 + speculative_config.num_speculative_tokens
def reorder_batch(self, input_batch: "InputBatch", def reorder_batch(self, input_batch: "InputBatch",
...@@ -437,7 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -437,7 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx] num_computed_tokens = input_batch.num_computed_tokens_cpu[req_idx]
num_prompt_tokens = input_batch.num_prompt_tokens[req_idx] num_prompt_tokens = input_batch.num_prompt_tokens[req_idx]
self.num_scheduled_tokens_np[i] = num_tokens self.num_scheduled_tokens_np[i] = num_tokens
if num_computed_tokens < num_prompt_tokens: if num_computed_tokens < num_prompt_tokens or (num_tokens > self.decode_token_num_threshold):
prefills.append(i) prefills.append(i)
num_prefill_tokens += num_tokens num_prefill_tokens += num_tokens
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment