Unverified Commit cba0d8c3 authored by Stefan He's avatar Stefan He Committed by GitHub
Browse files

[Feature] Support deterministic inference with FA3 backend (#10651)

parent f1d78923
......@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
self.sliding_window_size is not None and self.sliding_window_size > -1
)
# If num_splits == 0, we use a heuristic to automatically determine the number of splits.
# We set nums splits to 1 if deterministic inference is enabled.
# See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
self.num_splits = (
1 if model_runner.server_args.enable_deterministic_inference else 0
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Initialize forward metadata hence all layers in the forward pass can reuse it."""
metadata = FlashAttentionMetadata()
......@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
......@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
......@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
......@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
)
o, _ = merge_state_v2_wrapper(
......@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
elif use_local_attn:
......@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
num_splits=self.num_splits,
**kwargs,
)
else:
......@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
......@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
)
......@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
num_splits=self.num_splits,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
......@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
)
o, _ = merge_state_v2(
o,
......
......@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
# Allow external code to add more choices
......@@ -998,11 +998,13 @@ class ServerArgs:
"batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
)
# Check some settings
self.disable_radix_cache = True
logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
# Currently, only FA3 supports radix cache. Support for other backends is in progress
if self.attention_backend != "fa3":
self.disable_radix_cache = True
logger.warning(
"Currently radix cache is disabled for deterministic inference. It will be supported in the future."
)
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
raise ValueError(
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment