"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "805a202e1a0387a3d24d579d426773abd156962b"
Unverified Commit ef0e9d80 authored by Shen, committed by GitHub

Update: ElectraDiscriminatorPredictions forward. (#5471)

`ElectraDiscriminatorPredictions.forward` should not need `attention_mask`: the head only applies position-wise dense layers, so the mask is never used inside it.
parent 13a8588f
@@ -133,7 +133,7 @@ class ElectraDiscriminatorPredictions(nn.Module):
         self.dense_prediction = nn.Linear(config.hidden_size, 1)
         self.config = config
 
-    def forward(self, discriminator_hidden_states, attention_mask):
+    def forward(self, discriminator_hidden_states):
         hidden_states = self.dense(discriminator_hidden_states)
         hidden_states = get_activation(self.config.hidden_act)(hidden_states)
         logits = self.dense_prediction(hidden_states).squeeze()
@@ -518,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
         )
 
         discriminator_sequence_output = discriminator_hidden_states[0]
 
-        logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)
+        logits = self.discriminator_predictions(discriminator_sequence_output)
 
         output = (logits,)
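For context, a minimal runnable sketch of the head as it looks after this change. It mirrors the module shown in the diff, but the `get_activation` stand-in, the `self.dense` line in `__init__` (reconstructed from its use in `forward`), and the `DummyConfig` class are assumptions added so the snippet is self-contained, not library code. It illustrates why `attention_mask` is unnecessary: every layer in the head acts independently on each token position.

import torch
from torch import nn


def get_activation(name):
    # Stand-in for transformers' get_activation helper, assuming "gelu";
    # the real helper resolves any configured activation by name.
    assert name == "gelu"
    return nn.functional.gelu


class ElectraDiscriminatorPredictions(nn.Module):
    # Prediction head for replaced-token detection, post-commit:
    # no attention_mask parameter.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config

    def forward(self, discriminator_hidden_states):
        # Both linear layers are applied per token position, so no
        # cross-token masking can occur inside this module; padding is
        # handled upstream (e.g. in the pre-training loss).
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = get_activation(self.config.hidden_act)(hidden_states)
        logits = self.dense_prediction(hidden_states).squeeze()
        return logits


class DummyConfig:
    # Hypothetical minimal config for this demo; not a library class.
    hidden_size = 64
    hidden_act = "gelu"


head = ElectraDiscriminatorPredictions(DummyConfig())
hidden = torch.randn(2, 10, DummyConfig.hidden_size)  # (batch, seq_len, hidden)
print(head(hidden).shape)  # torch.Size([2, 10]) -- one logit per token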