Unverified Commit 4a53e8e9 authored by Jonathan Chang, committed by GitHub

Fix DataCollatorForWholeWordMask again (#8397)

parent 61073099
@@ -206,6 +206,10 @@ def _collate_batch(examples, tokenizer):
     return result
 
 
+def tolist(x: Union[List[Any], torch.Tensor]):
+    return x.tolist() if isinstance(x, torch.Tensor) else x
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
@@ -320,13 +324,13 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
         mask_labels = []
         for e in examples:
             ref_tokens = []
-            for id in e["input_ids"].tolist():
+            for id in tolist(e["input_ids"]):
                 token = self.tokenizer._convert_id_to_token(id)
                 ref_tokens.append(token)
 
             # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
             if "chinese_ref" in e:
-                ref_pos = e["chinese_ref"].tolist()
+                ref_pos = tolist(e["chinese_ref"])
                 len_seq = e["input_ids"].size(0)
                 for i in range(len_seq):
                     if i in ref_pos:
...
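
Why the change helps: per its type hint, the new module-level tolist() helper accepts either a plain Python list or a torch.Tensor, whereas .tolist() only exists on tensors, so examples whose "input_ids" or "chinese_ref" fields are plain lists (for instance, tokenized without return_tensors="pt") would previously raise an AttributeError. A minimal standalone sketch of that behavior, assuming PyTorch is installed; the token ids below are illustrative:

from typing import Any, List, Union

import torch


def tolist(x: Union[List[Any], torch.Tensor]):
    # Tensors are converted via .tolist(); plain Python lists pass through unchanged.
    return x.tolist() if isinstance(x, torch.Tensor) else x


assert tolist([101, 7592, 102]) == [101, 7592, 102]                # already a list
assert tolist(torch.tensor([101, 7592, 102])) == [101, 7592, 102]  # tensor input

With this helper, the reference-token construction in DataCollatorForWholeWordMask no longer assumes that these fields are tensors.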