"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a5847619e3e171bd711c0a193513d3b853d7d48a"
Unverified commit 0206efb4, authored by Sylvain Gugger, committed by GitHub

Make all data collators accept dict (#6065)

* Make all data collators accept dict

* Style
parent 31a5486e
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, NewType, Tuple
+from typing import Any, Callable, Dict, List, NewType, Tuple, Union

 import torch
 from torch.nn.utils.rnn import pad_sequence

@@ -77,7 +77,9 @@ class DataCollatorForLanguageModeling:
     mlm: bool = True
     mlm_probability: float = 0.15

-    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        if isinstance(examples[0], dict):
+            examples = [e["input_ids"] for e in examples]
         batch = self._tensorize_batch(examples)
         if self.mlm:
             inputs, labels = self.mask_tokens(batch)
@@ -148,7 +150,9 @@ class DataCollatorForPermutationLanguageModeling:
     plm_probability: float = 1 / 6
     max_span_length: int = 5  # maximum length of a span of masked tokens

-    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+    def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        if isinstance(examples[0], dict):
+            examples = [e["input_ids"] for e in examples]
         batch = self._tensorize_batch(examples)
         inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
         return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
...
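For context, here is a minimal usage sketch of the behaviour this change enables. It is not part of the commit: it assumes a transformers install from around this version, the bert-base-uncased checkpoint, and the DataCollatorForLanguageModeling class touched above. It shows the collator accepting either a list of input-id tensors (the old input shape) or a list of dicts keyed by "input_ids", as a map-style Dataset typically returns.

# Usage sketch (assumption, not from the commit): the collator can be fed
# either raw input_id tensors or dicts containing an "input_ids" key.
import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

texts = ["Hello world", "Data collators now accept dicts"]

# Old-style input: a list of 1-D tensors of token ids.
tensor_examples = [torch.tensor(tokenizer.encode(t)) for t in texts]

# New-style input: a list of dicts, e.g. what a Dataset's __getitem__ returns;
# the collator now pulls e["input_ids"] out of each example itself.
dict_examples = [{"input_ids": torch.tensor(tokenizer.encode(t))} for t in texts]

batch_a = collator(tensor_examples)
batch_b = collator(dict_examples)
print(batch_a.keys(), batch_b.keys())  # both contain "input_ids" and "labels"

Both code paths funnel into self._tensorize_batch, so padding and masking are unchanged; dict-shaped examples are simply reduced to their input_ids tensors before batching.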