Unverified Commit 6b7c2471 authored by Nicolas Castet's avatar Nicolas Castet Committed by GitHub
Browse files

Fix broken trtllm_mha attn backend with gpt-oss (#9161)

parent a027a9b4
...@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module): ...@@ -293,8 +293,12 @@ class GptOssAttention(nn.Module):
prefix=add_prefix("qkv_proj", prefix), prefix=add_prefix("qkv_proj", prefix),
) )
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
# others can use bfloat16
attn_backend = global_server_args_dict.get("attention_backend")
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
self.sinks = nn.Parameter( self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment