Unverified Commit e5ce395a authored by Ke Bao, committed by GitHub

Fix draft decode max batch size (#3676)

parent f983213a
@@ -1094,7 +1094,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
...
@@ -474,7 +474,7 @@ class TritonMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
-        max_bs = model_runner.req_to_token_pool.size
+        max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
                 self.speculative_num_steps,
...
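Note on the fix: in multi-step speculative decoding each request fans out into topk draft branches, so the draft decode path can see up to req_to_token_pool.size * topk sequences per step; buffers sized by the old max_bs overflow once batch_size * topk exceeds the pool size. A minimal Python sketch of the corrected sizing, with illustrative values (the full kv_indptr shape and the CSR-style "+ 1" slot are assumptions, since the diff truncates the torch.zeros call):

import torch

# Illustrative values; in the real code these come from model_runner
# and the speculative-decoding config.
pool_size = 32               # model_runner.req_to_token_pool.size
topk = 4                     # draft branches kept per request
speculative_num_steps = 3

# The fix: each request expands into topk draft sequences, so the
# effective max batch size is pool_size * topk, not pool_size.
max_bs = pool_size * topk

# One indptr row per draft step; the "+ 1" slot follows the usual
# exclusive-prefix-sum (CSR) convention (an assumption, not shown
# in the truncated diff).
kv_indptr = torch.zeros((speculative_num_steps, max_bs + 1), dtype=torch.int32)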
@@ -635,6 +635,9 @@ def decode_attention_fwd(
     logit_cap=0.0,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+    assert q.shape[0] <= kv_indptr.shape[0] - 1
+    assert q.shape[0] <= attn_logits.shape[0]
     kv_group_num = q.shape[1] // v_buffer.shape[1]
     if kv_group_num == 1:
...
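The two asserts added to decode_attention_fwd make the sizing contract explicit: every query row needs a (start, end) pair in kv_indptr, which holds one more entry than the batch it describes, and a slot in the attn_logits workspace. A hedged sketch of the invariant with made-up shapes (the exact attn_logits layout here is an assumption, not taken from the diff):

import torch

batch, num_heads, head_dim = 8, 16, 128
num_kv_splits = 4

q = torch.randn(batch, num_heads, head_dim)
kv_indptr = torch.zeros(batch + 1, dtype=torch.int32)  # CSR-style: batch + 1 entries
attn_logits = torch.empty(batch, num_heads, num_kv_splits, head_dim + 1)

# The guards added by this commit: they fail loudly instead of letting
# an undersized buffer cause silent out-of-bounds indexing.
assert num_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
assert q.shape[0] <= attn_logits.shape[0]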