Unverified Commit d329c9b0 authored by Teven's avatar Teven Committed by GitHub
Browse files

Fixed DataCollatorForLanguageModeling not accepting lists of lists (#6685)

* Fixed DataCollatorForLanguageModeling + PermutationLanguageModeling not accepting lists of lists

* Update data_collator.py

* black was grumpy
parent 0a850d21
......@@ -128,7 +128,9 @@ class DataCollatorForLanguageModeling:
mlm: bool = True
mlm_probability: float = 0.15
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
......@@ -141,7 +143,12 @@ class DataCollatorForLanguageModeling:
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
def _tensorize_batch(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
# In order to accept both lists of lists and lists of Tensors
if isinstance(examples[0], (list, tuple)):
examples = [torch.Tensor(e) for e in examples]
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
......@@ -202,14 +209,21 @@ class DataCollatorForPermutationLanguageModeling:
plm_probability: float = 1 / 6
max_span_length: int = 5 # maximum length of a span of masked tokens
def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
def __call__(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
batch = self._tensorize_batch(examples)
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
def _tensorize_batch(
self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
# In order to accept both lists of lists and lists of Tensors
if isinstance(examples[0], (list, tuple)):
examples = [torch.Tensor(e) for e in examples]
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment