Unverified Commit 0b418673 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

fix labels (#6213)

parent cedc547e
...@@ -87,7 +87,8 @@ class DataCollatorForLanguageModeling: ...@@ -87,7 +87,8 @@ class DataCollatorForLanguageModeling:
return {"input_ids": inputs, "labels": labels} return {"input_ids": inputs, "labels": labels}
else: else:
labels = batch.clone().detach() labels = batch.clone().detach()
labels[labels == self.tokenizer.pad_token_id] = -100 if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels} return {"input_ids": batch, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment