Unverified Commit d23cb9a0 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Eagle] reduce one draft forward (#3468)

parent 2d611323
...@@ -947,7 +947,7 @@ class FlashInferMultiStepDraftBackend: ...@@ -947,7 +947,7 @@ class FlashInferMultiStepDraftBackend:
triton.next_power_of_2(bs), triton.next_power_of_2(bs),
) )
for i in range(self.speculative_num_steps): for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][ forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1) : seq_lens_sum * self.topk + bs * (i + 1)
......
...@@ -234,6 +234,10 @@ class EAGLEWorker(TpModelWorker): ...@@ -234,6 +234,10 @@ class EAGLEWorker(TpModelWorker):
token_list.append(tree_info[1]) token_list.append(tree_info[1])
parents_list.append(tree_info[2]) parents_list.append(tree_info[2])
# we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
if i == self.speculative_num_steps - 1:
break
# Set inputs # Set inputs
forward_batch.input_ids = input_ids forward_batch.input_ids = input_ids
forward_batch.out_cache_loc = out_cache_loc[ forward_batch.out_cache_loc = out_cache_loc[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment