"vscode:/vscode.git/clone" did not exist on "53959eeeb9c5411ecc37f0de90159e29f6310a49"
Unverified commit cba0d8c3, authored by Stefan He and committed by GitHub

[Feature] Support deterministic inference with FA3 backend (#10651)

parent f1d78923
@@ -355,6 +355,13 @@ class FlashAttentionBackend(AttentionBackend):
             self.sliding_window_size is not None and self.sliding_window_size > -1
         )
+        # If num_splits == 0, we use a heuristic to automatically determine the number of splits.
+        # We set num_splits to 1 if deterministic inference is enabled.
+        # See https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/ for more details.
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata hence all layers in the forward pass can reuse it."""
         metadata = FlashAttentionMetadata()
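Why the num_splits pin above matters: with num_splits == 0 the FA3 kernel chooses its split-KV factor heuristically, so the same query can be reduced over the KV cache with a different split count depending on batch composition and sequence lengths, and floating-point reductions are order-sensitive. Forcing a single split fixes the reduction order. The sketch below is illustrative only (plain PyTorch, not sglang or FlashAttention code) and shows the underlying effect:

```python
import torch

# Illustration only: floating-point addition is not associative, so reducing the
# same values with a different number of splits can change the result in the
# last bits.
torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)

one_split = x.sum()  # a single reduction over all elements
# Eight splits: reduce each chunk, then reduce the partial sums. This mimics a
# split-KV kernel whose split count is chosen by a load-dependent heuristic.
eight_splits = torch.stack([chunk.sum() for chunk in x.chunk(8)]).sum()

print(one_split.item(), eight_splits.item())
print((one_split - eight_splits).abs().item())  # often small but nonzero
```

The blog post linked in the comment discusses the same effect for attention and matmul reductions in LLM serving.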
@@ -776,6 +783,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
@@ -797,6 +805,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             o, _ = merge_state_v2_wrapper(
@@ -901,6 +910,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -922,6 +932,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                 )
             )
             o, _ = merge_state_v2_wrapper(
@@ -1042,6 +1053,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         elif use_local_attn:
@@ -1061,6 +1073,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                num_splits=self.num_splits,
                 **kwargs,
             )
         else:
@@ -1089,6 +1102,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                num_splits=self.num_splits,
                 **kwargs,
             )
             if use_cascade_attn:
@@ -1110,6 +1124,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    num_splits=self.num_splits,
                     **kwargs,
                 )
             )
@@ -1165,6 +1180,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
+                num_splits=self.num_splits,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1185,6 +1201,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=True,
+                num_splits=self.num_splits,
             )
             o, _ = merge_state_v2(
                 o,
...
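Every attention call in this backend now receives self.num_splits, including the cascade paths that go through merge_state_v2 / merge_state_v2_wrapper, so the split decision is made once per backend rather than per call. The property this targets can be checked with a batch-invariance test like the hedged sketch below; the `forward` callable is a placeholder, not an sglang API:

```python
import torch

def is_batch_invariant(forward, x: torch.Tensor, copies: int = 8) -> bool:
    """Placeholder check, not an sglang API: run the same input alone and inside
    a larger batch and require bitwise-identical outputs from `forward`, any
    callable mapping a batched tensor to a batched tensor."""
    alone = forward(x.unsqueeze(0))[0]
    batched = forward(x.unsqueeze(0).repeat(copies, *([1] * x.dim())))[0]
    # Exact equality on purpose: determinism means identical bits, not allclose.
    return torch.equal(alone, batched)
```

With a heuristic split count this check can fail even though the math is identical, because batch sizes 1 and 8 may be split differently; with num_splits=1 (plus the batch-invariant ops required in server_args below) the outputs are expected to match exactly.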
@@ -118,7 +118,7 @@ DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
-DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer"]
+DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3"]
 # Allow external code to add more choices
@@ -998,11 +998,13 @@ class ServerArgs:
                     "batch_invariant_ops is not installed. Please install it from https://github.com/thinking-machines-lab/batch_invariant_ops/."
                 )
-            # Check some settings
-            self.disable_radix_cache = True
-            logger.warning(
-                "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
-            )
+            # Currently, only FA3 supports radix cache. Support for other backends is in progress
+            if self.attention_backend != "fa3":
+                self.disable_radix_cache = True
+                logger.warning(
+                    "Currently radix cache is disabled for deterministic inference. It will be supported in the future."
+                )
             if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
                 raise ValueError(
                     f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
...
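Taken together with the backend change, deterministic inference can now be requested with FA3. A minimal configuration sketch, assuming the usual sglang module path and treating model_path and all unlisted defaults as illustrative (only attention_backend and enable_deterministic_inference come from this diff):

```python
# Assumed module path; only the two fields noted below are taken from this diff.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    attention_backend="fa3",                         # from this diff
    enable_deterministic_inference=True,             # from this diff
)

# With this patch: "fa3" is in DETERMINISTIC_ATTENTION_BACKEND_CHOICES, so the
# ValueError above is not raised, and the radix cache stays enabled because the
# attention backend is "fa3".
```

On the command line this should correspond to --attention-backend fa3 --enable-deterministic-inference, assuming the standard dash-for-underscore flag naming.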