Unverified Commit d3e67deb authored by fzyzcjy, committed by GitHub

Fix redundant kernel in sink dtype conversion (#8966)

parent 442534aa
@@ -247,7 +247,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
         )
         self.o_proj = RowParallelLinear(
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
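The change moves the dtype conversion out of the hot path: the sinks parameter is now allocated in float32 at construction, so the checkpoint value is converted once when the weight is loaded, instead of `self.sinks.to(torch.float32)` launching an extra elementwise cast kernel on every forward call. A minimal, runnable sketch of the pattern follows; it is not the sglang code, and `fake_attn`, `q`, and the shapes are illustrative stand-ins for an attention backend that requires float32 sink logits.

import torch
import torch.nn as nn

def fake_attn(q, sinks):
    # Stand-in for the attention backend; it requires float32 sinks.
    assert sinks.dtype == torch.float32
    return q

num_heads, params_dtype = 8, torch.bfloat16
q = torch.randn(num_heads, 4)

# Before: parameter stored in the model dtype; every call pays for an
# elementwise cast kernel via .to(torch.float32).
sinks_before = nn.Parameter(
    torch.zeros(num_heads, dtype=params_dtype), requires_grad=False
)
out = fake_attn(q, sinks=sinks_before.to(torch.float32))

# After: parameter created in float32 up front; the conversion happens once
# when the checkpoint weight is copied in, and the forward pass hands the
# tensor to the backend unchanged.
sinks_after = nn.Parameter(
    torch.zeros(num_heads, dtype=torch.float32), requires_grad=False
)
out = fake_attn(q, sinks=sinks_after)

The per-step saving is small, but it removes a kernel launch from every decode iteration; the one-time cast cost presumably moves to weight loading, where the checkpoint tensor is copied into the float32 parameter.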