Commit c3527cfb authored by thomwolf
Browse files

ignore SQuAD targets outside of seq_length

parent 1b99cdf7
...@@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module): ...@@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension # If we are on multi-GPU, split add a dimension - if not this is a no-op
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1) end_positions = end_positions.squeeze(-1)
loss_fct = CrossEntropyLoss() # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) + 1
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions) start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment