Unverified Commit de46cde1 authored by Matt's avatar Matt Committed by GitHub
Browse files

Drop columns after loading samples in prepare_tf_dataset (#17967)

* Drop columns after loading samples, rather than before, to avoid breaking transforms

* make fixup

* Add workaround so this PR can work with current datasets version
parent 2544c143
...@@ -1238,18 +1238,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1238,18 +1238,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise TypeError("Dataset argument should be a datasets.Dataset!") raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys()) model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_labels = find_labels(self.__class__) model_labels = find_labels(self.__class__)
unwanted_columns = [ if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
feature output_signature, _ = dataset._get_output_signature(
for feature in dataset.features dataset,
if feature not in model_inputs and feature not in ("label_ids", "label") batch_size=None,
] collate_fn=collate_fn,
dataset = dataset.remove_columns(unwanted_columns) collate_fn_args=collate_fn_args,
output_signature, _ = dataset._get_output_signature( cols_to_retain=model_inputs,
dataset, )
batch_size=None, else:
collate_fn=collate_fn, # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
collate_fn_args=collate_fn_args, # argument. We should remove this once the minimum supported version of datasets is > 2.3.2
) unwanted_columns = [
feature
for feature in dataset.features
if feature not in model_inputs and feature not in ("label_ids", "label")
]
dataset = dataset.remove_columns(unwanted_columns)
output_signature, _ = dataset._get_output_signature(
dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
)
output_columns = list(output_signature.keys()) output_columns = list(output_signature.keys())
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels] feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
label_cols = [col for col in output_columns if col in model_labels] label_cols = [col for col in output_columns if col in model_labels]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment