Unverified Commit 36f6fc50 authored by Yineng Zhang, committed by GitHub

feat: enable ragged fa3 by default on hopper 12.4+ (#3442)

parent d8727275
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]

-        # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
-        )
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
+        )

         # Two wrappers: one for sliding window attention and one for full attention.
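This hunk makes the ragged prefill wrapper unconditional: it used to be created only when `num_wrappers == 1`, and the decision of whether to actually use it now lives in the multimodal check in the next hunk. For orientation, here is a minimal sketch of how a ragged (non-paged) prefill wrapper is driven; shapes and head counts are made up, and the `plan`/`run` names follow FlashInfer's current public API, whereas this diff still uses the older `begin_forward` spelling — check your installed version.

```python
import torch
import flashinfer

# Minimal sketch, assuming a CUDA machine with flashinfer installed.
# A ragged (non-paged) prefill wrapper attends q against contiguous k/v,
# so the same indptr can describe both sides — mirroring this diff's
# begin_forward(qo_indptr, qo_indptr, ...) call.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128  # made-up model shapes
qo_indptr = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
wrapper.plan(qo_indptr, qo_indptr, num_qo_heads, num_kv_heads, head_dim, causal=True)

q = torch.randn(12, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(12, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(12, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
o = wrapper.run(q, k, v)  # [12, num_qo_heads, head_dim]
```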
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 prefix_lens = forward_batch.extend_prefix_lens

-            # Some heuristics to check whether to use ragged forward
-            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-                use_ragged = True
-                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-            else:
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
+            else:
+                use_ragged = True
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
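The old size heuristic (`extend_num_tokens >= 4096` with a single wrapper) is gone: ragged prefill is now the default, and only multimodal models opt out, presumably because their custom attention masks require the paged path. A hedged restatement of the new gating as a standalone helper (hypothetical, not part of the diff):

```python
# Hypothetical helper restating the gating introduced by this hunk.
def choose_prefill_path(is_multimodal: bool, extend_prefix_lens_cpu: list[int]):
    """Return (use_ragged, extend_no_prefix) for a prefill batch."""
    if is_multimodal:
        return False, False  # multimodal models stay on the paged path
    # extend_no_prefix is True when no request in the batch carries a
    # cached prefix; if this backend follows the usual pattern, that lets
    # it skip merging with a separate attention pass over cached tokens.
    return True, not any(extend_prefix_lens_cpu)

print(choose_prefill_path(False, [0, 0, 0]))  # (True, True)
print(choose_prefill_path(False, [3, 0, 0]))  # (True, False)
print(choose_prefill_path(True, [0, 0, 0]))   # (False, False)
```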
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
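Here, and in the two prefill hunks below, the explicit `wrapper.end_forward()` teardown between batches is dropped; `begin_forward()` replans from scratch, so the reset is redundant on the FlashInfer versions this commit targets (an assumption to verify against your release notes). The general pattern, shown with a toy stateful planner rather than the real wrapper:

```python
# Toy stateful planner (hypothetical class, not FlashInfer) illustrating
# why an explicit teardown call can be dropped when replanning overwrites
# all per-batch state anyway.
class Planner:
    def __init__(self):
        self._plan = None

    def begin_forward(self, indptr):
        self._plan = list(indptr)  # overwrite: implicitly ends the last plan

    def forward(self):
        assert self._plan is not None
        return len(self._plan) - 1  # batch size under the current plan

p = Planner()
for indptr in ([0, 3, 7], [0, 2]):
    p.begin_forward(indptr)  # no end_forward() needed between iterations
    print(p.forward())       # prints 2, then 1
```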
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )
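`non_blocking=True` asks the wrapper to issue its plan-time host-to-device index copies asynchronously, letting CPU-side batch preparation overlap with GPU work. At the PyTorch level the mechanism looks roughly like the following sketch of the general technique (illustrative only; FlashInfer manages its own buffers internally):

```python
import torch

# Async H2D copies require page-locked (pinned) host memory; otherwise
# non_blocking silently degrades to a synchronous copy.
host_indptr = torch.tensor([0, 4, 9], dtype=torch.int32).pin_memory()
device_indptr = torch.empty_like(host_indptr, device="cuda")
device_indptr.copy_(host_indptr, non_blocking=True)  # returns immediately
# ... the CPU can prepare the next batch while the copy is in flight ...
torch.cuda.current_stream().synchronize()  # copy is guaranteed done here
print(device_indptr)
```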
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )

         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )
@@ -1125,6 +1121,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
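Since the updaters now pass `non_blocking=True` to every planning call, the monkey-patched `fast_decode_plan` must tolerate keyword arguments it does not consume; the trailing `**kwargs` absorbs them. The pattern in isolation (stub names here are hypothetical):

```python
from typing import Optional

# Hypothetical stub showing the forward-compatibility pattern: trailing
# **kwargs swallows new keyword args (such as non_blocking=True) that
# callers pass but this fast path does not need, avoiding a TypeError.
def fast_plan_stub(
    indptr: list,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    **kwargs,
) -> None:
    _ = kwargs  # intentionally ignored by this illustrative stub
    print(f"planning batch of {len(indptr) - 1} requests")

fast_plan_stub([0, 4, 9], non_blocking=True)  # no TypeError
```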