Unverified Commit 73ec4340 authored by Matt's avatar Matt Committed by GitHub
Browse files

Make DefaultDataCollator importable from root (#14588)

* Make DefaultDataCollator importable from root

* Add documentation for DefaultDataCollator and add return_tensors argument to all class docstrings

* make style

* Add DefaultDataCollator to data_collator.rst

* Add DefaultDataCollator to data_collator.rst
parent 71b1bf7e
......@@ -29,6 +29,13 @@ Default data collator
.. autofunction:: transformers.data.data_collator.default_data_collator
DefaultDataCollator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DefaultDataCollator
:members:
DataCollatorWithPadding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -92,6 +92,7 @@ _import_structure = {
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithPadding",
"DefaultDataCollator",
"default_data_collator",
],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
......@@ -2087,6 +2088,7 @@ if TYPE_CHECKING:
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
)
from .feature_extraction_sequence_utils import SequenceFeatureExtractor
......
......@@ -24,6 +24,7 @@ from .data_collator import (
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
DefaultDataCollator,
default_data_collator,
)
from .metrics import glue_compute_metrics, xnli_compute_metrics
......
......@@ -72,6 +72,24 @@ def default_data_collator(features: List[InputDataClass], return_tensors="pt") -
@dataclass
class DefaultDataCollator(DataCollatorMixin):
"""
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
potential keys named:
- ``label``: handles a single value (int or float) per object
- ``label_ids``: handles a list of values per object
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
to the model. See glue and ner for example of how it's useful.
This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
helpful if you need to set a return_tensors value at initialization.
Args:
return_tensors (:obj:`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
return_tensors: str = "pt"
def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
......@@ -214,6 +232,8 @@ class DataCollatorWithPadding:
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
return_tensors (:obj:`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
......@@ -266,6 +286,8 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
return_tensors (:obj:`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
......@@ -519,6 +541,8 @@ class DataCollatorForSeq2Seq:
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (:obj:`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
......@@ -591,6 +615,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
return_tensors (:obj:`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
.. note::
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment