Unverified Commit 7b95825d authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

Remove columns before passing to data collator (#17187)

parent 934e21cd
...@@ -607,13 +607,14 @@ class Trainer: ...@@ -607,13 +607,14 @@ class Trainer:
# Inspect model forward signature to keep only the arguments it accepts. # Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward) signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys()) self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns: if not self.args.remove_unused_columns:
return dataset return dataset
self._set_signature_columns_if_needed() self._set_signature_columns_if_needed()
# Labels may be named label or label_ids, the default data collator handles that. signature_columns = self._signature_columns
signature_columns = self._signature_columns + ["label", "label_ids"]
ignored_columns = list(set(dataset.column_names) - set(signature_columns)) ignored_columns = list(set(dataset.column_names) - set(signature_columns))
if len(ignored_columns) > 0: if len(ignored_columns) > 0:
...@@ -642,7 +643,7 @@ class Trainer: ...@@ -642,7 +643,7 @@ class Trainer:
if not self.args.remove_unused_columns: if not self.args.remove_unused_columns:
return data_collator return data_collator
self._set_signature_columns_if_needed() self._set_signature_columns_if_needed()
signature_columns = self._signature_columns + self.label_names signature_columns = self._signature_columns
remove_columns_collator = RemoveColumnsCollator( remove_columns_collator = RemoveColumnsCollator(
data_collator=data_collator, data_collator=data_collator,
......
...@@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum): ...@@ -658,7 +658,7 @@ class FSDPOption(ExplicitEnum):
class RemoveColumnsCollator: class RemoveColumnsCollator:
"""Wrap the data collator to remove unused columns from its output.""" """Wrap the data collator to remove unused columns before they are passed to the collator."""
def __init__( def __init__(
self, self,
...@@ -690,4 +690,5 @@ class RemoveColumnsCollator: ...@@ -690,4 +690,5 @@ class RemoveColumnsCollator:
return {k: v for k, v in feature.items() if k in self.signature_columns} return {k: v for k, v in feature.items() if k in self.signature_columns}
def __call__(self, features: List[dict]): def __call__(self, features: List[dict]):
return self._remove_columns(self.data_collator(features)) features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment