Commit cba0d8c3 authored by Stefan He, committed by GitHub

[Feature] Support deterministic inference with FA3 backend (#10651)

parent f1d78923
@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set num_splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
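
The comment added above points at the root cause described in the linked Thinking Machines post: split-KV ("flash-decoding") attention reduces a number of partial results that depends on the batch composition, and floating-point addition is not associative, so the same request can yield slightly different logits from run to run. A minimal, self-contained illustration of that effect (plain PyTorch reduction, not the FA3 kernel itself):

    import torch

    torch.manual_seed(0)
    x = torch.randn(1 << 16, dtype=torch.float32)

    # One reduction versus split reductions over the same data: mathematically
    # identical, numerically not, because float addition is not associative.
    full = x.sum()
    split_4 = torch.stack([chunk.sum() for chunk in x.chunk(4)]).sum()
    split_64 = torch.stack([chunk.sum() for chunk in x.chunk(64)]).sum()
    print(full.item(), split_4.item(), split_64.item())  # typically differ in the last bits

Pinning num_splits to 1 fixes the reduction order, trading some decode throughput for bit-wise reproducibility.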
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=use_cascade_attn,
+            num_splits=self.num_splits,
             **kwargs,
         )
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         o, _ = merge_state_v2_wrapper(
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=use_cascade_attn,
+            num_splits=self.num_splits,
         )
         if use_cascade_attn:
             o, softmax_lse, *rest = result
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
             )
         )
         o, _ = merge_state_v2_wrapper(
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
             softcap=layer.logit_cap,
             k_descale=k_descale,
             v_descale=v_descale,
+            num_splits=self.num_splits,
             **kwargs,
         )
     elif use_local_attn:
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
             softcap=layer.logit_cap,
             k_descale=k_descale,
             v_descale=v_descale,
+            num_splits=self.num_splits,
             **kwargs,
         )
     else:
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=use_cascade_attn,
+            num_splits=self.num_splits,
             **kwargs,
         )
         if use_cascade_attn:
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         )
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+            num_splits=self.num_splits,
         )
         if use_cascade_attn:
             o, softmax_lse, *rest = result
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=k_descale,
             v_descale=v_descale,
             return_softmax_lse=True,
+            num_splits=self.num_splits,
         )
         o, _ = merge_state_v2(
             o,
...
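
All of the hunks above thread the same self.num_splits value into the FA3 kernel call sites (prefill, decode, local attention, and the cascade paths). As a rough sketch of what each call site reduces to, here is the public flash_attn_with_kvcache API from the flash-attn package; the import path, shapes, and values are assumptions for illustration, not taken from this diff:

    import torch
    from flash_attn import flash_attn_with_kvcache  # assumed entry point; SGLang wraps its own FA3 build

    # Toy decode step: batch=2, one new query token per request, 8 heads, head_dim=64.
    q = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
    k_cache = torch.randn(2, 256, 8, 64, dtype=torch.float16, device="cuda")
    v_cache = torch.randn(2, 256, 8, 64, dtype=torch.float16, device="cuda")
    cache_seqlens = torch.tensor([128, 200], dtype=torch.int32, device="cuda")

    # num_splits=0 lets the kernel choose the split-KV factor per batch (fast but
    # order-dependent); num_splits=1 forces a single split and a fixed reduction order.
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        causal=True,
        num_splits=1,
    )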
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
 # Allow external code to add more choices
@@ -998,11 +998,13 @@ class ServerArgs:
                 "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
             )
-        # Check some settings
+        # Currently, only FA3 supports radix cache. Support for other backends is in progress.
+        if self.attention_backend != "fa3":
             self.disable_radix_cache = True
             logger.warning(
                 "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
             )
         if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
             raise ValueError(
                 f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
...
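
With the validation above in place, deterministic inference can now be requested together with the FA3 backend. A hedged end-to-end sketch using the offline engine follows; the keyword names mirror the ServerArgs fields touched in this diff, the model path is a placeholder, and the exact keywords accepted by sgl.Engine may differ between SGLang versions:

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        attention_backend="fa3",
        enable_deterministic_inference=True,
    )

    params = {"temperature": 0.0, "max_new_tokens": 32}
    out_a = llm.generate("Explain KV-cache paging in one sentence.", params)
    out_b = llm.generate("Explain KV-cache paging in one sentence.", params)
    # With a deterministic backend, repeated requests are expected to be
    # bit-identical even when batched alongside different concurrent traffic.
    assert out_a["text"] == out_b["text"]
    llm.shutdown()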