"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "efdd436663436e78d8ad3213d11325d86578db95"
Unverified commit 20fa8289, authored by Sylvain Gugger, committed by GitHub

Make default_data_collator more flexible and deprecate old behavior (#5060)

* Make default_data_collator more flexible

* Accept tensors for all features

* Document code

* Refactor

* Formatting
parent 5e069633
@@ -33,31 +33,34 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
     # have the same attributes.
     # So we will look at the first element as a proxy for what attributes exist
     # on the whole batch.
+    if not isinstance(features[0], dict):
+        features = [vars(f) for f in features]
     first = features[0]
+    batch = {}
 
     # Special handling for labels.
     # Ensure that tensor is created with the correct type
     # (it should be automatically the case, but let's make sure of it.)
-    if hasattr(first, "label") and first.label is not None:
-        if type(first.label) is int:
-            labels = torch.tensor([f.label for f in features], dtype=torch.long)
-        else:
-            labels = torch.tensor([f.label for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    elif hasattr(first, "label_ids") and first.label_ids is not None:
-        if type(first.label_ids[0]) is int:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
-        else:
-            labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
-        batch = {"labels": labels}
-    else:
-        batch = {}
+    if "label" in first:
+        dtype = torch.long if type(first["label"]) is int else torch.float
+        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
+    elif "label_ids" in first:
+        if isinstance(first["label_ids"], torch.Tensor):
+            batch["labels"] = torch.stack([f["label_ids"] for f in features])
+        else:
+            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
+            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
 
-    # Handling of all other possible attributes.
+    # Handling of all other possible keys.
     # Again, we will use the first element to figure out which key/values are not None for this model.
-    for k, v in vars(first).items():
+    for k, v in first.items():
         if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
-            batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
+            if isinstance(v, torch.Tensor):
+                batch[k] = torch.stack([f[k] for f in features])
+            else:
+                batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
 
     return batch
...
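For reference, a minimal usage sketch of the reworked collator (not part of the commit): it now accepts plain dicts, dataclass-style objects (converted through `vars()`), and features that are already tensors. The top-level import path and the feature names below (`input_ids`, `pixel_values`) are illustrative assumptions, not taken from the diff.

```python
from dataclasses import dataclass
from typing import List

import torch

from transformers import default_data_collator  # assumed top-level import


@dataclass
class Example:
    label: int
    input_ids: List[int]


# Dataclass features go through the new `vars(f)` conversion branch.
batch = default_data_collator([Example(label=i, input_ids=[0, 1, 2, 3]) for i in range(4)])
print(batch["labels"].dtype)        # torch.long, inferred from the int label
print(batch["input_ids"].shape)     # torch.Size([4, 4])

# Tensor-valued features are stacked instead of being re-wrapped with torch.tensor.
batch = default_data_collator(
    [{"label": float(i), "pixel_values": torch.zeros(3, 8, 8)} for i in range(4)]
)
print(batch["labels"].dtype)        # torch.float, inferred from the float label
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 8, 8])
```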
@@ -4,6 +4,7 @@ import os
 import random
 import re
 import shutil
+import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple
@@ -205,6 +206,15 @@ class Trainer:
             # Set an xla_device flag on the model's config.
             # We'll find a more elegant and not need to do this in the future.
             self.model.config.xla_device = True
+        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+            self.data_collator = self.data_collator.collate_batch
+            warnings.warn(
+                (
+                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
+                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
+                ),
+                FutureWarning,
+            )
 
     def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
...
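To illustrate the deprecation shim added to `Trainer.__init__` above, here is a hedged sketch of the old and new collator styles. `OldStyleCollator` and `NewStyleCollator` are hypothetical names, and their batching logic is deliberately simplified; only the shim condition itself mirrors the diff.

```python
from typing import Any, Dict, List

import torch


class OldStyleCollator:
    """Deprecated pattern (hypothetical example): batching logic lives in `collate_batch`."""

    def collate_batch(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        return {k: torch.tensor([f[k] for f in features]) for k in features[0]}


class NewStyleCollator:
    """Recommended pattern: the collator instance itself is callable."""

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        return {k: torch.tensor([f[k] for f in features]) for k in features[0]}


# What the shim above does: an old-style collator is rewired to its
# `collate_batch` method (with a FutureWarning), so training keeps working.
collator = OldStyleCollator()
if not callable(collator) and callable(getattr(collator, "collate_batch", None)):
    collator = collator.collate_batch

batch = collator([{"labels": 0, "inputs": [1, 2, 3]}, {"labels": 1, "inputs": [4, 5, 6]}])
print(batch["inputs"].shape)  # torch.Size([2, 3])
```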
@@ -24,6 +24,27 @@ PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
 @require_torch
 class DataCollatorIntegrationTest(unittest.TestCase):
+    def test_default_with_dict(self):
+        features = [{"labels": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # With label_ids
+        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor([[0, 1, 2]] * 8)))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
+
+        # Features can already be tensors
+        features = [{"labels": i, "inputs": torch.randint(10, [10])} for i in range(8)]
+        batch = default_data_collator(features)
+        self.assertTrue(batch["labels"].equal(torch.tensor(list(range(8)))))
+        self.assertEqual(batch["labels"].dtype, torch.long)
+        self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
+
     def test_default_classification(self):
         MODEL_ID = "bert-base-cased-finetuned-mrpc"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
...
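The new test covers integer labels and tensor features; as a small complementary sketch (not part of the commit), float labels should come out as `torch.float`, mirroring the dtype inference in the collator change above. The top-level import path is assumed.

```python
import torch

from transformers import default_data_collator  # assumed top-level import

features = [{"label": 0.5 * i, "inputs": [0, 1, 2]} for i in range(4)]
batch = default_data_collator(features)
assert batch["labels"].dtype == torch.float  # float label -> torch.float
assert batch["inputs"].dtype == torch.long   # other list features default to torch.long
```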