Unverified commit edcc66d2 authored by Antoni Baum, committed by GitHub

Remove unnecessary columns for all dataset types in `Trainer` (#17166)

* Remove unneeded columns for IterableDataset

* Add test

* Update trainer tests

* Edit docstring

* Lint

* Apply feedback

* Apply feedback
parent c33f6046
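
In short: `remove_unused_columns=True` previously only took effect when the dataset was a `datasets.Dataset`, whose columns can be dropped in place. This PR extends the behavior to every dataset type by wrapping the data collator in the new `RemoveColumnsCollator`, so unused columns are stripped batch by batch instead. A rough usage sketch of what now works end to end (the toy iterable dataset and the commented-out model are illustrative assumptions, not code from the PR):

import torch
from transformers import Trainer, TrainingArguments

class ToyIterable(torch.utils.data.IterableDataset):
    def __iter__(self):
        for _ in range(8):
            # "extra" has no matching argument in model.forward(); after this
            # PR the wrapped collator drops it instead of passing it through.
            yield {"input_ids": torch.tensor([1, 2]), "labels": torch.tensor(0.0), "extra": 3}

# args = TrainingArguments(output_dir="out", max_steps=2)  # remove_unused_columns defaults to True
# trainer = Trainer(model=model, args=args, train_dataset=ToyIterable())
# trainer.train()  # "extra" never reaches model.forward()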
@@ -109,6 +109,7 @@ from .trainer_utils import (
     HubStrategy,
     IntervalStrategy,
     PredictionOutput,
+    RemoveColumnsCollator,
     ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
@@ -601,27 +602,30 @@ class Trainer:
         if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
             model.tie_weights()
 
-    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
-        if not self.args.remove_unused_columns:
-            return dataset
+    def _set_signature_columns_if_needed(self):
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
             signature = inspect.signature(self.model.forward)
             self._signature_columns = list(signature.parameters.keys())
 
+    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+        if not self.args.remove_unused_columns:
+            return dataset
+        self._set_signature_columns_if_needed()
         # Labels may be named label or label_ids, the default data collator handles that.
-        self._signature_columns += ["label", "label_ids"]
+        signature_columns = self._signature_columns + ["label", "label_ids"]
 
-        ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
+        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
         if len(ignored_columns) > 0:
-            dset_description = "" if description is None else f"in the {description} set "
+            dset_description = "" if description is None else f"in the {description} set"
             logger.info(
                 f"The following columns {dset_description} don't have a corresponding argument in "
                 f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
                 f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
-                f" you can safely ignore this message."
+                " you can safely ignore this message."
             )
 
-        columns = [k for k in self._signature_columns if k in dataset.column_names]
+        columns = [k for k in signature_columns if k in dataset.column_names]
 
         if version.parse(datasets.__version__) < version.parse("1.4.0"):
             dataset.set_format(
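
For readers skimming the diff: the column selection above boils down to inspecting the model's forward signature. A minimal standalone sketch of the same idea (the toy model and column names are assumptions for illustration):

import inspect

import torch

class TinyModel(torch.nn.Module):
    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return input_ids

# Keep only columns whose names match forward() parameters, plus the
# conventional label columns the default data collator understands.
signature_columns = list(inspect.signature(TinyModel().forward).parameters.keys())
signature_columns += ["label", "label_ids"]

dataset_columns = ["input_ids", "attention_mask", "labels", "raw_text"]
ignored_columns = list(set(dataset_columns) - set(signature_columns))
print(ignored_columns)  # ['raw_text'] -- this is what gets dropped and logged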
@@ -631,6 +635,24 @@ class Trainer:
         else:
             return dataset.remove_columns(ignored_columns)
 
+    def _get_collator_with_removed_columns(
+        self, data_collator: Callable, description: Optional[str] = None
+    ) -> Callable:
+        """Wrap the data collator in a callable removing unused columns."""
+        if not self.args.remove_unused_columns:
+            return data_collator
+        self._set_signature_columns_if_needed()
+        signature_columns = self._signature_columns + self.label_names
+
+        remove_columns_collator = RemoveColumnsCollator(
+            data_collator=data_collator,
+            signature_columns=signature_columns,
+            logger=logger,
+            description=description,
+            model_name=self.model.__class__.__name__,
+        )
+        return remove_columns_collator
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.train_dataset is None or not has_length(self.train_dataset):
             return None
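
Behaviorally, the helper returns a drop-in `collate_fn`: in this version it runs the inner collator first, then filters the collated batch down to `signature_columns`. A minimal sketch (the feature dicts and column names are illustrative):

from transformers.data.data_collator import default_data_collator
from transformers.trainer_utils import RemoveColumnsCollator

features = [{"input_ids": [1, 2], "labels": 0, "extra": 3}] * 2
wrapped = RemoveColumnsCollator(
    data_collator=default_data_collator,
    signature_columns=["input_ids", "labels"],
)
batch = wrapped(features)
assert "extra" not in batch  # stripped from the collated batch
assert set(batch) == {"input_ids", "labels"}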
@@ -717,8 +739,11 @@ class Trainer:
             raise ValueError("Trainer: training requires a train_dataset.")
 
         train_dataset = self.train_dataset
+        data_collator = self.data_collator
         if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
 
         if isinstance(train_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
@@ -733,7 +758,7 @@ class Trainer:
             return DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
-                collate_fn=self.data_collator,
+                collate_fn=data_collator,
                 num_workers=self.args.dataloader_num_workers,
                 pin_memory=self.args.dataloader_pin_memory,
             )
@@ -744,7 +769,7 @@ class Trainer:
             train_dataset,
             batch_size=self._train_batch_size,
             sampler=train_sampler,
-            collate_fn=self.data_collator,
+            collate_fn=data_collator,
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
@@ -794,9 +819,12 @@ class Trainer:
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+        data_collator = self.data_collator
 
         if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
 
         if isinstance(eval_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
@@ -810,7 +838,7 @@ class Trainer:
             return DataLoader(
                 eval_dataset,
                 batch_size=self.args.eval_batch_size,
-                collate_fn=self.data_collator,
+                collate_fn=data_collator,
                 num_workers=self.args.dataloader_num_workers,
                 pin_memory=self.args.dataloader_pin_memory,
             )
@@ -821,7 +849,7 @@ class Trainer:
             eval_dataset,
             sampler=eval_sampler,
             batch_size=self.args.eval_batch_size,
-            collate_fn=self.data_collator,
+            collate_fn=data_collator,
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
@@ -838,8 +866,12 @@ class Trainer:
                 The test dataset to use. If it is a `datasets.Dataset`, columns not accepted by the `model.forward()`
                 method are automatically removed. It must implement `__len__`.
         """
+        data_collator = self.data_collator
+
         if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
 
         if isinstance(test_dataset, torch.utils.data.IterableDataset):
             if self.args.world_size > 1:
@@ -853,7 +885,7 @@ class Trainer:
             return DataLoader(
                 test_dataset,
                 batch_size=self.args.eval_batch_size,
-                collate_fn=self.data_collator,
+                collate_fn=data_collator,
                 num_workers=self.args.dataloader_num_workers,
                 pin_memory=self.args.dataloader_pin_memory,
             )
@@ -865,7 +897,7 @@ class Trainer:
             test_dataset,
             sampler=test_sampler,
             batch_size=self.args.eval_batch_size,
-            collate_fn=self.data_collator,
+            collate_fn=data_collator,
             drop_last=self.args.dataloader_drop_last,
             pin_memory=self.args.dataloader_pin_memory,
         )
......
@@ -25,7 +25,7 @@ import random
 import re
 import threading
 import time
-from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
@@ -655,3 +655,39 @@ class FSDPOption(ExplicitEnum):
     SHARD_GRAD_OP = "shard_grad_op"
     OFFLOAD = "offload"
     AUTO_WRAP = "auto_wrap"
+
+
+class RemoveColumnsCollator:
+    """Wrap the data collator to remove unused columns from its output."""
+
+    def __init__(
+        self,
+        data_collator,
+        signature_columns,
+        logger=None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+    ):
+        self.data_collator = data_collator
+        self.signature_columns = signature_columns
+        self.logger = logger
+        self.description = description
+        self.model_name = model_name
+        self.message_logged = False
+
+    def _remove_columns(self, feature: dict) -> dict:
+        if not self.message_logged and self.logger and self.model_name:
+            ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
+            if len(ignored_columns) > 0:
+                dset_description = "" if self.description is None else f"in the {self.description} set"
+                self.logger.info(
+                    f"The following columns {dset_description} don't have a corresponding argument in "
+                    f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
+                    f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
+                    " you can safely ignore this message."
+                )
+                self.message_logged = True
+        return {k: v for k, v in feature.items() if k in self.signature_columns}
+
+    def __call__(self, features: List[dict]):
+        return self._remove_columns(self.data_collator(features))
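
Because the wrapper is an ordinary callable, it also composes with a plain PyTorch `DataLoader`; a sketch assuming a simple list-of-dicts dataset (the sample fields are illustrative):

from torch.utils.data import DataLoader

from transformers.data.data_collator import default_data_collator
from transformers.trainer_utils import RemoveColumnsCollator

samples = [{"x": float(i), "labels": i % 2, "extra": -1} for i in range(4)]
loader = DataLoader(
    samples,
    batch_size=2,
    collate_fn=RemoveColumnsCollator(default_data_collator, ["x", "labels"]),
)
for batch in loader:
    assert "extra" not in batch

Note the `message_logged` flag: since the collator runs on every batch, the informational message fires once per wrapper rather than once per batch, matching the one-time logging of the dataset-level `_remove_unused_columns` path.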
@@ -289,8 +289,7 @@ class TrainingArguments:
             [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
             set to warn or lower (default), `False` otherwise.
         remove_unused_columns (`bool`, *optional*, defaults to `True`):
-            If using `datasets.Dataset` datasets, whether or not to automatically remove the columns unused by the
-            model forward method.
+            Whether or not to automatically remove the columns unused by the model forward method.
 
             (Note that this behavior is not implemented for [`TFTrainer`] yet.)
         label_names (`List[str]`, *optional*):
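
The docstring change reflects that the flag now applies regardless of dataset type. Opting out stays a single argument; a sketch (the output_dir value is illustrative):

from transformers import TrainingArguments

# Keep every column, e.g. when a custom collator needs fields that
# model.forward() does not accept.
args = TrainingArguments(output_dir="out", remove_unused_columns=False)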
......
@@ -1329,7 +1329,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
     def test_training_iterable_dataset(self):
         config = RegressionModelConfig()
         model = RegressionPreTrainedModel(config)
-        train_dataset = SampleIterableDataset()
+        # Adding one column not used by the model should have no impact
+        train_dataset = SampleIterableDataset(label_names=["labels", "extra"])
 
         args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
         trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
@@ -1363,7 +1364,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
     def test_evaluation_iterable_dataset(self):
         config = RegressionModelConfig(a=1.5, b=2.5)
         model = RegressionPreTrainedModel(config)
-        eval_dataset = SampleIterableDataset()
+        # Adding one column not used by the model should have no impact
+        eval_dataset = SampleIterableDataset(label_names=["labels", "extra"])
 
         args = RegressionTrainingArguments(output_dir="./examples")
         trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
@@ -1400,7 +1402,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
 
         # With a number of elements not a round multiple of the batch size
-        test_dataset = SampleIterableDataset(length=66)
+        # Adding one column not used by the model should have no impact
+        test_dataset = SampleIterableDataset(length=66, label_names=["labels", "extra"])
         preds = trainer.predict(test_dataset).predictions
         x = test_dataset.dataset.x
         self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
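
`SampleIterableDataset` is a helper defined elsewhere in this test file; `label_names` controls which extra keys each sample carries. As a rough stand-in to show what the updated tests exercise (an assumption for illustration, including the `input_x` argument name, not the actual helper):

import numpy as np
import torch

class ToyRegressionIterable(torch.utils.data.IterableDataset):
    """Yields regression samples plus an 'extra' column the model ignores."""

    def __init__(self, length=64):
        self.x = np.random.normal(size=(length,)).astype(np.float32)
        self.ys = 1.5 * self.x + 2.5

    def __iter__(self):
        for x, y in zip(self.x, self.ys):
            yield {"input_x": x, "labels": y, "extra": 0}

With `remove_unused_columns` left at its default of `True`, the wrapped collator silently drops "extra", so `trainer.predict` still recovers the 1.5 * x + 2.5 line.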
......
@@ -18,8 +18,9 @@ import unittest
 
 import numpy as np
 
+from transformers.data.data_collator import default_data_collator
 from transformers.testing_utils import require_accelerate, require_torch
-from transformers.trainer_utils import find_executable_batch_size
+from transformers.trainer_utils import RemoveColumnsCollator, find_executable_batch_size
 from transformers.utils import is_torch_available
@@ -457,3 +458,28 @@ class TrainerUtilsTest(unittest.TestCase):
             with self.assertRaises(RuntimeError) as cm:
                 mock_training_loop_function()
                 self.assertEqual("CUDA out of memory", cm.args[0])
+
+    def test_remove_columns_collator(self):
+        class MockLogger:
+            def __init__(self) -> None:
+                self.called = 0
+
+            def info(self, msg):
+                self.called += 1
+                self.last_msg = msg
+
+        data_batch = [
+            {"col1": 1, "col2": 2, "col3": 3},
+            {"col1": 1, "col2": 2, "col3": 3},
+        ]
+        logger = MockLogger()
+        remove_columns_collator = RemoveColumnsCollator(
+            default_data_collator, ["col1", "col2"], logger, "model", "training"
+        )
+        self.assertNotIn("col3", remove_columns_collator(data_batch))
+        # check that the logging message is printed out only once
+        remove_columns_collator(data_batch)
+        remove_columns_collator(data_batch)
+        self.assertEqual(logger.called, 1)
+        self.assertIn("col3", logger.last_msg)