Commit 63276b76 (unverified), authored by Sylvain Gugger, committed by GitHub
Browse files

Fix #7284 (#7289)

parent 8d464374
...@@ -434,13 +434,15 @@ class DataCollatorForNextSentencePrediction: ...@@ -434,13 +434,15 @@ class DataCollatorForNextSentencePrediction:
else: else:
input_ids = self._tensorize_batch(input_ids) input_ids = self._tensorize_batch(input_ids)
return { result = {
"input_ids": input_ids, "input_ids": input_ids,
"attention_mask": self._tensorize_batch(attention_masks), "attention_mask": self._tensorize_batch(attention_masks),
"token_type_ids": self._tensorize_batch(segment_ids), "token_type_ids": self._tensorize_batch(segment_ids),
"masked_lm_labels": mlm_labels if self.mlm else None,
"next_sentence_label": torch.tensor(nsp_labels), "next_sentence_label": torch.tensor(nsp_labels),
} }
if self.mlm:
result["masked_lm_labels"] = mlm_labels
return result
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0) length_of_first = examples[0].size(0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment