Unverified Commit 0ff6d1fc authored by Ke Bao, committed by GitHub

Support FA3 backend for gpt-oss (#9028)

parent 4a16a71c
@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
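The only packaging change is the sgl-kernel pin, bumped to 0.3.4.post1, presumably so the installed FA3 kernel build exposes the new `sinks` argument. A minimal, illustrative runtime check for that pin (not part of this commit; the helper name is made up):

```python
# Illustrative only: confirm the installed sgl-kernel satisfies the new pin
# before relying on the FA3 sinks path. Not part of this commit.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def sgl_kernel_supports_fa3_sinks(min_version: str = "0.3.4.post1") -> bool:
    """Return True if the installed sgl-kernel is at least `min_version`."""
    try:
        return Version(version("sgl-kernel")) >= Version(min_version)
    except PackageNotFoundError:
        return False


print(sgl_kernel_supports_fa3_sinks())
```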
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
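The hunks above thread an optional `sinks` tensor through both the decode and extend paths of the FA3 backend. A minimal sketch of the conditional-keyword pattern they use (not the SGLang code; `fake_flash_attn` is a stand-in for `flash_attn_with_kvcache` / `flash_attn_varlen_func`), which keeps a single call site compatible with older flash-attn builds whose signatures do not accept `sinks`:

```python
# Sketch of the conditional-keyword pattern used in the hunks above: `sinks`
# is only forwarded when it is set, so older builds never see the new keyword.
from typing import Optional

import torch


def fake_flash_attn(q: torch.Tensor, **kwargs) -> torch.Tensor:
    # Stand-in for the real kernel; an older build would raise TypeError if it
    # received an unexpected `sinks` keyword argument.
    print("received keywords:", sorted(kwargs))
    return q


def call_attention(q: torch.Tensor, sinks: Optional[torch.Tensor] = None) -> torch.Tensor:
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks  # only added when the new field is actually used
    return fake_flash_attn(q, **kwargs)


call_attention(torch.zeros(1))                        # received keywords: []
call_attention(torch.zeros(1), sinks=torch.zeros(4))  # received keywords: ['sinks']
```

Passing the key only when it is set lets one code path serve both old and new kernel interfaces, at the cost of a slightly less explicit call signature.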
...@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module): ...@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
) )
self.sinks = nn.Parameter( self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
......
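`GptOssAttention` stores one sink logit per head, now kept in bfloat16, presumably to match the dtype the FA3 kernel expects for this argument (the commit does not say so explicitly). For reference, a rough, unoptimized sketch of the math a per-head sink is commonly taken to implement, based on the gpt-oss sink formulation rather than the FA3 kernel code: the sink logit joins the softmax normalization but contributes no value vector, so it only damps the attention weights.

```python
# Rough reference of per-head attention sinks (causal masking omitted for brevity).
# This is an assumption about the math behind the kernel's `sinks` argument,
# not the FA3 implementation.
import torch


def attention_with_sinks(q, k, v, sinks, scale):
    # q, k, v: [num_heads, seq_len, head_dim]; sinks: [num_heads]
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    sink_col = sinks.view(-1, 1, 1).expand(-1, logits.shape[1], 1)
    probs = torch.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)
    probs = probs[..., :-1]  # drop the sink column; it only enlarges the denominator
    return torch.einsum("hqk,hkd->hqd", probs, v)


h, s, d = 4, 8, 16
out = attention_with_sinks(
    torch.randn(h, s, d), torch.randn(h, s, d), torch.randn(h, s, d),
    sinks=torch.zeros(h), scale=d ** -0.5,
)
print(out.shape)  # torch.Size([4, 8, 16])
```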
@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
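With this change, `GptOssForCausalLM` accepts `fa3` in addition to `triton` (still the default) and `trtllm_mha`. A standalone sketch of that resolution logic, just to make the accepted values and the failure mode easy to see (illustrative only; the real check lives in `ServerArgs`):

```python
# Standalone sketch of the backend check above, mirroring the diff's behavior.
from typing import Optional

SUPPORTED_GPT_OSS_BACKENDS = ["triton", "trtllm_mha", "fa3"]


def resolve_gpt_oss_attention_backend(requested: Optional[str]) -> str:
    backend = requested or "triton"  # same default as ServerArgs
    assert backend in SUPPORTED_GPT_OSS_BACKENDS, (
        f"GptOssForCausalLM requires one of {SUPPORTED_GPT_OSS_BACKENDS} "
        f"attention backend, but got '{backend}'"
    )
    return backend


print(resolve_gpt_oss_attention_backend(None))   # triton
print(resolve_gpt_oss_attention_backend("fa3"))  # fa3
```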