"...git@developer.sourcefind.cn:change/sglang.git" did not exist on "432f2053ddfe545abddb6252520dc21f7ee2b410"
Unverified Commit ef0e9d80 authored by Shen's avatar Shen Committed by GitHub
Browse files

Update: ElectraDiscriminatorPredictions forward. (#5471)

`ElectraDiscriminatorPredictions.forward` should not need `attention_mask`.
parent 13a8588f
......@@ -133,7 +133,7 @@ class ElectraDiscriminatorPredictions(nn.Module):
self.dense_prediction = nn.Linear(config.hidden_size, 1)
self.config = config
def forward(self, discriminator_hidden_states, attention_mask):
def forward(self, discriminator_hidden_states):
hidden_states = self.dense(discriminator_hidden_states)
hidden_states = get_activation(self.config.hidden_act)(hidden_states)
logits = self.dense_prediction(hidden_states).squeeze()
......@@ -518,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)
logits = self.discriminator_predictions(discriminator_sequence_output)
output = (logits,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment