Unverified commit 07df5578, authored by Aktsvigun, committed by GitHub
Browse files

pad_to_multiple_of added to DataCollatorForWholeWordMask (#12999)



* pad_to_multiple_of added to DataCollatorForWholeWordMask

* pad_to_multiple_of added to DataCollatorForWholeWordMask
Co-authored-by: Цвигун Аким Олегович <AOTsvigun@sberbank.ru>
parent 3f44a66c
...@@ -418,7 +418,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): ...@@ -418,7 +418,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
input_ids = examples input_ids = examples
examples = [{"input_ids": e} for e in examples] examples = [{"input_ids": e} for e in examples]
batch_input = _collate_batch(input_ids, self.tokenizer) batch_input = _collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
mask_labels = [] mask_labels = []
for e in examples: for e in examples:
...@@ -435,7 +435,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): ...@@ -435,7 +435,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
if i in ref_pos: if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i] ref_tokens[i] = "##" + ref_tokens[i]
mask_labels.append(self._whole_word_mask(ref_tokens)) mask_labels.append(self._whole_word_mask(ref_tokens))
batch_mask = _collate_batch(mask_labels, self.tokenizer) batch_mask = _collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
inputs, labels = self.mask_tokens(batch_input, batch_mask) inputs, labels = self.mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels} return {"input_ids": inputs, "labels": labels}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment