Commit 4dd5cf22 (unverified), authored by Sylvain Gugger, committed by GitHub
Browse files

Fix argument label (#4792)

* Fix argument label

* Fix test
parent 3723f30a
...@@ -91,7 +91,7 @@ class DataCollatorForLanguageModeling(DataCollator): ...@@ -91,7 +91,7 @@ class DataCollatorForLanguageModeling(DataCollator):
batch = self._tensorize_batch(examples) batch = self._tensorize_batch(examples)
if self.mlm: if self.mlm:
inputs, labels = self.mask_tokens(batch) inputs, labels = self.mask_tokens(batch)
- return {"input_ids": inputs, "masked_lm_labels": labels}
+ return {"input_ids": inputs, "labels": labels}
else: else:
return {"input_ids": batch, "labels": batch} return {"input_ids": batch, "labels": batch}
......
...@@ -74,14 +74,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): ...@@ -74,14 +74,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
batch = data_collator.collate_batch(examples) batch = data_collator.collate_batch(examples)
self.assertIsInstance(batch, dict) self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107))) self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
- self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((31, 107)))
+ self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))] examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples) batch = data_collator.collate_batch(examples)
self.assertIsInstance(batch, dict) self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
- self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((2, 512)))
+ self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment