Unverified Commit 854260ca authored by Matt's avatar Matt Committed by GitHub
Browse files

TF/Numpy variants for all DataCollator classes (#13105)



* Adding a TF variant of the DataCollatorForTokenClassification to get feedback

* Added a Numpy variant and a post_init check to fail early if a missing import is found

* Fixed call to Numpy variant

* Added a couple more of the collators

* Update src/transformers/data/data_collator.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fixes, style pass, finished DataCollatorForSeqToSeq

* Added all the LanguageModeling DataCollators, except SOP and PermutationLanguageModeling

* Adding DataCollatorForPermutationLanguageModeling

* Style pass

* Add missing `__call__` for PLM

* Remove `post_init` checks for frameworks because the imports inside them were making us fail code quality checks

* Remove unused imports

* First attempt at some TF tests

* A second attempt to make any of those tests actually work

* TF tests, round three

* TF tests, round four

* TF tests, round five

* TF tests, all enabled!

* Style pass

* Merging tests into `test_data_collator.py`

* Merging tests into `test_data_collator.py`

* Fixing up test imports

* Fixing up test imports

* Trying shuffling the conditionals around

* Commenting out non-functional old tests

* Completed all tests for all three frameworks

* Style pass

* Fixed test typo

* Style pass

* Move standard `__call__` method to mixin

* Rearranged imports for `test_data_collator`

* Fix data collator typo "torch" -> "pt"

* Fixed the most embarrassingly obvious bug

* Update src/transformers/data/data_collator.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Renaming mixin

* Updating docs
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarDalton Walker <dalton_walker@icloud.com>
Co-authored-by: default avatarAndrew Romans <andrew.romans@hotmail.com>
parent 74b3344f
...@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling ...@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling .. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling
:members: mask_tokens :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
DataCollatorForWholeWordMask DataCollatorForWholeWordMask
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask .. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask
:members: mask_tokens :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
DataCollatorForPermutationLanguageModeling DataCollatorForPermutationLanguageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling .. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling
:members: mask_tokens :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
...@@ -81,6 +81,17 @@ _import_structure = { ...@@ -81,6 +81,17 @@ _import_structure = {
"xnli_processors", "xnli_processors",
"xnli_tasks_num_labels", "xnli_tasks_num_labels",
], ],
"data.data_collator": [
"DataCollator",
"DataCollatorForLanguageModeling",
"DataCollatorForPermutationLanguageModeling",
"DataCollatorForSeq2Seq",
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithPadding",
"default_data_collator",
],
"feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"], "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
"file_utils": [ "file_utils": [
"CONFIG_NAME", "CONFIG_NAME",
...@@ -460,17 +471,6 @@ else: ...@@ -460,17 +471,6 @@ else:
if is_torch_available(): if is_torch_available():
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["data.data_collator"] = [
"DataCollator",
"DataCollatorForLanguageModeling",
"DataCollatorForPermutationLanguageModeling",
"DataCollatorForSeq2Seq",
"DataCollatorForSOP",
"DataCollatorForTokenClassification",
"DataCollatorForWholeWordMask",
"DataCollatorWithPadding",
"default_data_collator",
]
_import_structure["data.datasets"] = [ _import_structure["data.datasets"] = [
"GlueDataset", "GlueDataset",
"GlueDataTrainingArguments", "GlueDataTrainingArguments",
...@@ -1830,6 +1830,17 @@ if TYPE_CHECKING: ...@@ -1830,6 +1830,17 @@ if TYPE_CHECKING:
xnli_processors, xnli_processors,
xnli_tasks_num_labels, xnli_tasks_num_labels,
) )
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
)
# Feature Extractor # Feature Extractor
from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
...@@ -2174,17 +2185,6 @@ if TYPE_CHECKING: ...@@ -2174,17 +2185,6 @@ if TYPE_CHECKING:
# Benchmarks # Benchmarks
from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .data.data_collator import (
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSeq2Seq,
DataCollatorForSOP,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
)
from .data.datasets import ( from .data.datasets import (
GlueDataset, GlueDataset,
GlueDataTrainingArguments, GlueDataTrainingArguments,
......
This diff is collapsed.
...@@ -12,62 +12,6 @@ class PyTorchBenchmarkArguments: ...@@ -12,62 +12,6 @@ class PyTorchBenchmarkArguments:
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class DataCollator:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForLanguageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForPermutationLanguageModeling:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForSeq2Seq:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForSOP:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DataCollatorForWholeWordMask:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DataCollatorWithPadding:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def default_data_collator(*args, **kwargs):
requires_backends(default_data_collator, ["torch"])
class GlueDataset: class GlueDataset:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment