"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "2945abd703ee83fff02fbcad4806a506289e0105"
Unverified Commit 0a3d0e02 authored by Setu Shah's avatar Setu Shah Committed by GitHub
Browse files

Replace labels with -100 to skip loss calc (#4718)

parent 6894b486
...@@ -82,7 +82,9 @@ class DataCollatorForLanguageModeling: ...@@ -82,7 +82,9 @@ class DataCollatorForLanguageModeling:
inputs, labels = self.mask_tokens(batch) inputs, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "labels": labels} return {"input_ids": inputs, "labels": labels}
else: else:
return {"input_ids": batch, "labels": batch} labels = batch.clone().detach()
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0) length_of_first = examples[0].size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment