Commit 0a850d21 authored by sgugger's avatar sgugger
Browse files

Missing commit

parent b30879fe
......@@ -257,10 +257,7 @@ class Trainer:
logger.info(
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
)
ds_type = dataset.format["type"]
if ds_type == "python":
ds_type = None
dataset.set_format(type=ds_type, columns=columns)
dataset.set_format(type=dataset.format["type"], columns=columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment