Unverified Commit d2370e1b authored by Sylvain Gugger, committed by GitHub

Adding DataCollatorWithPadding (#6442)

* Data collator with padding

* Add type annotation

* Support tensors as well

* Add comment

* Fix for labels wrong shape


* Remove changes rendered unnecessary
parent 96c3329f
@@ -438,6 +438,7 @@ if is_torch_available():
         DataCollator,
         DataCollatorForLanguageModeling,
         DataCollatorForPermutationLanguageModeling,
+        DataCollatorWithPadding,
     )
     from .data.datasets import (
         GlueDataset,
......
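With the re-export above, the new collator becomes importable straight from the package root (the name sits behind is_torch_available(), so a PyTorch install is assumed):

    from transformers import DataCollatorWithPadding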
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, NewType, Tuple, Union
+from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

 import torch
 from torch.nn.utils.rnn import pad_sequence

 from ..tokenization_utils import PreTrainedTokenizer
-from ..tokenization_utils_base import BatchEncoding
+from ..tokenization_utils_base import BatchEncoding, PaddingStrategy
 from ..tokenization_utils_fast import PreTrainedTokenizerFast

 InputDataClass = NewType("InputDataClass", Any)
@@ -66,6 +67,55 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
     return batch


+@dataclass
+class DataCollatorWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received.
+
+    Args:
+        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+            The tokenizer used for encoding the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding
+            index) among:
+
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+              single sequence is provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
+              different lengths).
+        max_length (:obj:`int`, `optional`):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set, will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+            >= 7.5 (Volta).
+    """
+
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # Pad every feature in the batch to a common length via the tokenizer.
+        batch = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors="pt",
+        )
+        # Trainer expects the targets under "labels", so rename the
+        # "label"/"label_ids" keys that datasets typically produce.
+        if "label" in batch:
+            batch["labels"] = batch["label"]
+            del batch["label"]
+        if "label_ids" in batch:
+            batch["labels"] = batch["label_ids"]
+            del batch["label_ids"]
+        return batch
+
+
 @dataclass
 class DataCollatorForLanguageModeling:
     """
......
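For context (not part of the diff), here is a minimal usage sketch of the new collator. The checkpoint name and sentences are placeholders, and the printed shape depends on the tokenizer; the pattern is to tokenize without padding and let the collator pad each batch on the fly.

    from torch.utils.data import DataLoader

    from transformers import AutoTokenizer, DataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

    # Tokenize without padding; padding happens per batch inside the collator.
    sentences = ["A short sentence.", "A noticeably longer example sentence for the same batch."]
    features = [tokenizer(s) for s in sentences]

    # pad_to_multiple_of=8 rounds each batch's length up, e.g. for Tensor Core efficiency.
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

    # Calling the collator directly shows the dynamic padding at work.
    batch = collator(features)
    print(batch["input_ids"].shape)  # e.g. torch.Size([2, 16])

    # The same object plugs into a DataLoader as collate_fn.
    loader = DataLoader(features, batch_size=2, collate_fn=collator)
    for batch in loader:
        print(batch["attention_mask"].shape)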