Unverified Commit 854260ca authored by Matt's avatar Matt Committed by GitHub
Browse files

TF/Numpy variants for all DataCollator classes (#13105)



* Adding a TF variant of the DataCollatorForTokenClassification to get feedback

* Added a Numpy variant and a post_init check to fail early if a missing import is found

* Fixed call to Numpy variant

* Added a couple more of the collators

* Update src/transformers/data/data_collator.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fixes, style pass, finished DataCollatorForSeqToSeq

* Added all the LanguageModeling DataCollators, except SOP and PermutationLanguageModeling

* Adding DataCollatorForPermutationLanguageModeling

* Style pass

* Add missing `__call__` for PLM

* Remove `post_init` checks for frameworks because the imports inside them were making us fail code quality checks

* Remove unused imports

* First attempt at some TF tests

* A second attempt to make any of those tests actually work

* TF tests, round three

* TF tests, round four

* TF tests, round five

* TF tests, all enabled!

* Style pass

* Merging tests into `test_data_collator.py`

* Merging tests into `test_data_collator.py`

* Fixing up test imports

* Fixing up test imports

* Trying shuffling the conditionals around

* Commenting out non-functional old tests

* Completed all tests for all three frameworks

* Style pass

* Fixed test typo

* Style pass

* Move standard `__call__` method to mixin

* Rearranged imports for `test_data_collator`

* Fix data collator typo "torch" -> "pt"

* Fixed the most embarrassingly obvious bug

* Update src/transformers/data/data_collator.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Renaming mixin

* Updating docs
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarDalton Walker <dalton_walker@icloud.com>
Co-authored-by: default avatarAndrew Romans <andrew.romans@hotmail.com>
parent 74b3344f
......@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling
:members: mask_tokens
:members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
DataCollatorForWholeWordMask
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask
:members: mask_tokens
:members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
DataCollatorForPermutationLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling
:members: mask_tokens
:members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
......@@ -81,6 +81,17 @@ _import_structure = {
"xnli_processors",
"xnli_tasks_num_labels",
],
"data.data_collator": [
"DataCollator",
"DataCollatorForLanguageModeling",
"DataCollatorForPermutationLanguageModeling",
"DataCollatorForSeq2Seq",
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithPadding",
"default_data_collator",
],
"feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
"file_utils": [
"CONFIG_NAME",
......@@ -460,17 +471,6 @@ else:
if is_torch_available():
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["data.data_collator"] = [
"DataCollator",
"DataCollatorForLanguageModeling",
"DataCollatorForPermutationLanguageModeling",
"DataCollatorForSeq2Seq",
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithPadding",
"default_data_collator",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
......@@ -1830,6 +1830,17 @@ if TYPE_CHECKING:
xnli_processors,
xnli_tasks_num_labels,
)
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
)
# Feature Extractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
......@@ -2174,17 +2185,6 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
......
This diff is collapsed.
......@@ -12,62 +12,6 @@ class PyTorchBenchmarkArguments:
requires_backends(self, ["torch"])
class DataCollator:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForLanguageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForPermutationLanguageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForSeq2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForSOP:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForWholeWordMask:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorWithPadding:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def default_data_collator(*args, **kwargs):
requires_backends(default_data_collator, ["torch"])
class GlueDataset:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment