Unverified Commit c4059ea5 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

[Bugfix] Add explicit `end_forward` calls to flashinfer (#6044)

parent 8e0817c2
...@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -126,6 +126,7 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr, self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len, self.paged_kv_indices, self.paged_kv_last_page_len,
...@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -142,6 +143,7 @@ class FlashInferMetadata(AttentionMetadata):
self.device) self.device)
assert self.decode_wrapper is not None assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward( self.decode_wrapper.begin_forward(
self.paged_kv_indptr, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_indices,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment