Unverified commit 25b8f9a8, authored by Yih-Dar and committed by GitHub
Browse files

Fix FlaxRoFormerClassificationHead activation (#16168)



* fix activation
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 03c14a51
@@ -594,12 +594,13 @@ class FlaxRoFormerClassificationHead(nn.Module):
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         )
+        self.activation = ACT2FN[self.config.hidden_act]
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
         hidden_states = self.dropout(hidden_states, deterministic=deterministic)
         hidden_states = self.dense(hidden_states)
-        hidden_states = nn.tanh(hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.dropout(hidden_states, deterministic=deterministic)
         hidden_states = self.out_proj(hidden_states)
         return hidden_states
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment