"tasks/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "259062c2de0fe77de16f76d7cf575ec4890b1ef5"
Unverified commit 77a257fc authored by Jonathan Chang, committed by GitHub

Fix DataCollatorForWholeWordMask (#8379)

* Fix DataCollatorForWholeWordMask

* Replace all _tensorize_batch calls in data_collator.py
parent 517eaf46
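
The diff below replaces each collator's private _tensorize_batch method with calls to a shared, module-level _collate_batch(examples, tokenizer) helper. The helper's own definition is not part of the hunks shown here; the sketch below is an assumption reconstructed from the removed _tensorize_batch body and the new call sites, not necessarily the exact helper added by the commit:

import torch
from torch.nn.utils.rnn import pad_sequence

def _collate_batch(examples, tokenizer):
    # Sketch only: mirrors the removed _tensorize_batch, but as a free function
    # that takes the tokenizer explicitly instead of reading self.tokenizer.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]
    length_of_first = examples[0].size(0)
    if all(x.size(0) == length_of_first for x in examples):
        # Equal-length sequences can simply be stacked into one batch tensor.
        return torch.stack(examples, dim=0)
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using "
            f"({tokenizer.__class__.__name__}) does not have one."
        )
    # Ragged batch: right-pad every sequence with the tokenizer's pad token id.
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
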
@@ -315,7 +315,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
             input_ids = examples
             examples = [{"input_ids": e} for e in examples]
-        batch_input = self._tensorize_batch(input_ids)
+        batch_input = _collate_batch(input_ids, self.tokenizer)
         mask_labels = []
         for e in examples:
@@ -332,7 +332,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
                     if i in ref_pos:
                         ref_tokens[i] = "##" + ref_tokens[i]
             mask_labels.append(self._whole_word_mask(ref_tokens))
-        batch_mask = self._tensorize_batch(mask_labels)
+        batch_mask = _collate_batch(mask_labels, self.tokenizer)
         inputs, labels = self.mask_tokens(batch_input, batch_mask)
         return {"input_ids": inputs, "labels": labels}
@@ -511,28 +511,10 @@ class DataCollatorForPermutationLanguageModeling:
     ) -> Dict[str, torch.Tensor]:
         if isinstance(examples[0], (dict, BatchEncoding)):
             examples = [e["input_ids"] for e in examples]
-        batch = self._tensorize_batch(examples)
+        batch = _collate_batch(examples, self.tokenizer)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
-    def _tensorize_batch(
-        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
-    ) -> torch.Tensor:
-        # In order to accept both lists of lists and lists of Tensors
-        if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
-        length_of_first = examples[0].size(0)
-        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-        if are_tensors_same_length:
-            return torch.stack(examples, dim=0)
-        else:
-            if self.tokenizer._pad_token is None:
-                raise ValueError(
-                    "You are attempting to pad samples but the tokenizer you are using"
-                    f" ({self.tokenizer.__class__.__name__}) does not have one."
-                )
-            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
     def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
...
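
For reference, here is a minimal usage sketch of the fixed DataCollatorForWholeWordMask from the first hunk above. The checkpoint name and masking probability are illustrative choices, not part of this commit, and exact behaviour can vary across transformers versions:

from transformers import BertTokenizerFast, DataCollatorForWholeWordMask

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")  # example checkpoint, not from the commit
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

# Pre-tokenized examples of different lengths, in the dict form that __call__ accepts.
examples = [
    {"input_ids": tokenizer("The quick brown fox", return_tensors="pt")["input_ids"][0]},
    {"input_ids": tokenizer("jumps over the lazy dog today", return_tensors="pt")["input_ids"][0]},
]

batch = collator(examples)
# With the fix, ragged inputs are collated and padded by the shared helper,
# and the collator returns masked input_ids plus the matching labels.
print(batch["input_ids"].shape, batch["labels"].shape)
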