Unverified commit b30879fe, authored by Sylvain Gugger, committed by GitHub
Browse files

Don't reset the dataset type + plug for rm unused columns (#6683)



* Don't reset the type of the dataset

* Formatting

* Update trainer.py
Co-authored-by: Teven <teven.lescao@gmail.com>
parent 1a779ad7
...@@ -244,6 +244,8 @@ class Trainer: ...@@ -244,6 +244,8 @@ class Trainer:
self.scaler = torch.cuda.amp.GradScaler() self.scaler = torch.cuda.amp.GradScaler()
def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None): def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return
# Inspect model forward signature to keep only the arguments it accepts. # Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward) signature = inspect.signature(self.model.forward)
signature_columns = list(signature.parameters.keys()) signature_columns = list(signature.parameters.keys())
...@@ -255,7 +257,10 @@ class Trainer: ...@@ -255,7 +257,10 @@ class Trainer:
logger.info( logger.info(
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
) )
dataset.set_format(columns=columns) ds_type = dataset.format["type"]
if ds_type == "python":
ds_type = None
dataset.set_format(type=ds_type, columns=columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset): if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
......
...@@ -114,6 +114,11 @@ class TrainingArguments: ...@@ -114,6 +114,11 @@ class TrainingArguments:
at the next training step under the keyword argument ``mems``. at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`): run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging. A descriptor for the run. Notably used for wandb logging.
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
forward method.
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
""" """
output_dir: str = field( output_dir: str = field(
...@@ -234,6 +239,10 @@ class TrainingArguments: ...@@ -234,6 +239,10 @@ class TrainingArguments:
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."} default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
) )
remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
)
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment