"csrc/vscode:/vscode.git/clone" did not exist on "7d4d742bf03f8e1707130391e0b39bd6d93a702a"
Unverified Commit 36f6fc50 authored by Yineng Zhang, committed by GitHub

feat: enable ragged fa3 by default on hopper 12.4+ (#3442)

parent d8727275
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]
 
-        # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
-        )
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
+        )
 
         # Two wrappers: one for sliding window attention and one for full attention.
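Note on the hunk above: the ragged prefill wrapper is now constructed unconditionally, even when multiple paged wrappers exist (e.g. for sliding-window attention), instead of being None outside the single-wrapper case. For reference, a minimal standalone sketch of creating such a wrapper with flashinfer follows; the 128 MB workspace size is an illustrative assumption, not necessarily what this backend allocates.

import torch
import flashinfer

# Illustrative workspace size; the backend's own buffer may differ.
workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)
# "NHD" = (num_tokens, num_heads, head_dim) layout for the ragged KV cache.
prefill_wrapper_ragged = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, "NHD"
)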
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-            # Some heuristics to check whether to use ragged forward
-            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-                use_ragged = True
-                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-            else:
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
+            else:
+                use_ragged = True
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
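Reading the hunk above: the old heuristic enabled ragged prefill only for large extend batches (extend_num_tokens >= 4096) and only when a single wrapper was in use; the new code enables it whenever the model is not multimodal. A sketch of the new decision as a standalone helper is below; the function name and signature are mine for illustration, not part of the repository.

from typing import List, Tuple

def should_use_ragged_prefill(
    is_multimodal: bool, extend_prefix_lens_cpu: List[int]
) -> Tuple[bool, bool]:
    # Hypothetical helper restating the heuristic in the hunk above.
    if is_multimodal:
        # Multimodal models keep the paged (non-ragged) prefill path.
        return False, False
    # Otherwise ragged prefill is always used; "no prefix" means no request
    # in the batch carries cached prefix tokens.
    return True, not any(extend_prefix_lens_cpu)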
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )
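The added non_blocking=True asks the wrapper's plan/begin_forward call to move the freshly built indptr/indices metadata to the GPU without forcing a host synchronization. The underlying PyTorch pattern such a flag relies on is sketched below; whether flashinfer stages these copies through pinned host memory internally is an assumption here, not something this diff shows.

import torch

# Generic async host-to-device copy: pinning the host buffer is what lets the
# copy overlap with GPU work; without pinned memory the transfer degrades to a
# synchronous copy even if non_blocking=True is passed.
host_indices = torch.arange(1024, dtype=torch.int32).pin_memory()
device_indices = host_indices.to("cuda", non_blocking=True)
# Kernels launched afterwards on the same CUDA stream see the completed copy;
# only synchronize if the host itself must wait for the data.
torch.cuda.current_stream().synchronize()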
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )
@@ -1125,6 +1121,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
......
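On the last hunk: fast_decode_plan is installed in place of flashinfer's BatchDecodeWithPagedKVCacheWrapper.plan for the multi-step draft backend, so its signature has to keep accepting whatever keyword arguments call sites of plan now pass, such as the non_blocking=True introduced above. Adding **kwargs is a common way to keep such an override forward-compatible; the sketch below uses hypothetical names to illustrate the pattern.

class Planner:
    def plan(self, indptr, non_blocking: bool = False) -> None:
        # Original (slower) implementation being replaced.
        ...

def fast_plan(self, indptr, **kwargs) -> None:
    # Accept, and here simply ignore, keyword arguments added to the original
    # plan() signature (e.g. non_blocking=True), so updated call sites never
    # fail with "unexpected keyword argument" against the patched method.
    ...

Planner.plan = fast_plan                            # monkey-patch the method
Planner().plan(indptr=[0, 4], non_blocking=True)    # works with the new kwarg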