Unverified commit d3e67deb authored by fzyzcjy, committed by GitHub

Fix redundant kernel in sink dtype conversion (#8966)

parent 442534aa
@@ -247,7 +247,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
         )
         self.o_proj = RowParallelLinear(
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
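
For readers outside the codebase, here is a minimal standalone sketch of the pattern this commit removes; everything except the names `sinks`, `num_heads`, and `params_dtype` is invented for illustration. The point: calling `Tensor.to(torch.float32)` on a non-float32 tensor allocates a new tensor and launches a cast kernel on every forward pass, whereas creating the parameter in float32 up front pays that cost only once at initialization.

```python
import torch
import torch.nn as nn

# Illustrative values; in the real model these come from the config.
num_heads, params_dtype = 64, torch.bfloat16

# Before this commit: sinks stored in the model dtype (e.g. bf16),
# so every forward pass launched a redundant cast kernel.
sinks_old = nn.Parameter(
    torch.empty(num_heads, dtype=params_dtype), requires_grad=False
)
per_forward_copy = sinks_old.to(torch.float32)  # new tensor + cast kernel each call

# After this commit: the parameter is created in float32 once, at init.
sinks_new = nn.Parameter(
    torch.empty(num_heads, dtype=torch.float32), requires_grad=False
)
# The attention call can now take self.sinks directly; no per-step conversion.
assert sinks_new.dtype == torch.float32
```

Note that a `.to()` call whose target dtype already matches returns the same tensor without launching a kernel, but dropping the call entirely, as the patch does, also avoids the per-step dispatch overhead.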