Unverified Commit ba1b3db7 authored by donggyukimc's avatar donggyukimc Committed by GitHub
Browse files

fix wrong 'cls' masking for bigbird qa model output (#13143)

parent 7a26307e
...@@ -2987,6 +2987,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): ...@@ -2987,6 +2987,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
if token_type_ids is None: if token_type_ids is None:
token_type_ids = (~logits_mask).long() token_type_ids = (~logits_mask).long()
logits_mask = logits_mask logits_mask = logits_mask
logits_mask[:, 0] = False
logits_mask.unsqueeze_(2) logits_mask.unsqueeze_(2)
outputs = self.bert( outputs = self.bert(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment