Unverified Commit 000cceca authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 68373d31
...@@ -308,6 +308,15 @@ class Attention(nn.Module): ...@@ -308,6 +308,15 @@ class Attention(nn.Module):
if hasattr(self.impl, "process_weights_after_loading"): if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype) self.impl.process_weights_after_loading(act_dtype)
# FlashInfer requires attention sinks to be float32
if (self.backend == _Backend.FLASHINFER_VLLM_V1
and hasattr(self.impl, 'sinks')):
from vllm.v1.attention.backends.flashinfer import FlashInferImpl
assert isinstance(self.impl, FlashInferImpl)
if (self.impl.sinks is not None
and self.impl.sinks.dtype != torch.float32):
self.impl.sinks = self.impl.sinks.to(torch.float32)
def get_attn_backend(self) -> type[AttentionBackend]: def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend return self.attn_backend
......
...@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl): ...@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
f"heads in the layer. Expected {num_heads}, but got " f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}." f"{sinks.shape[0]}."
) )
# Cast sinks to float32 if needed (FlashInfer requirement)
if sinks.dtype != torch.float32:
sinks = sinks.to(torch.float32)
self.sinks = sinks self.sinks = sinks
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment