Commit 78d833ae authored by zhuwenwen's avatar zhuwenwen
Browse files

fix ds-v2 run error

parent 5b0a1c93
......@@ -103,6 +103,7 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......@@ -117,6 +118,7 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......@@ -129,6 +131,7 @@ def flash_mla_with_kvcache(
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......
......@@ -884,7 +884,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
HAS_BIAS=HAS_BIAS,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config,
)
......
......@@ -490,7 +490,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.block_table = block_table
self.use_spec_decode = False
self.num_scheduled_tokens_np = np.zeros(scheduler_config.max_num_seqs, dtype=np.int32)
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
......
......@@ -18,6 +18,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
......@@ -630,7 +631,7 @@ def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
num_scheduled_tokens_np: Optional[np.ndarray] = None
# num_scheduled_tokens_np: np.ndarray = np.zeros(256, dtype=np.int32),
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
......@@ -649,6 +650,9 @@ def reorder_batch_to_split_decodes_and_prefills(
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
vllm_config = VllmConfig()
num_scheduled_tokens_np = np.zeros(vllm_config.scheduler_config.max_num_seqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment