Unverified Commit de46cde1 authored by Matt's avatar Matt Committed by GitHub
Browse files

Drop columns after loading samples in prepare_tf_dataset (#17967)

* Drop columns after loading samples, rather than before, to avoid breaking transforms

* make fixup

* Add workaround so this PR can work with current datasets version
parent 2544c143
...@@ -1238,6 +1238,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1238,6 +1238,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise TypeError("Dataset argument should be a datasets.Dataset!") raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys()) model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_labels = find_labels(self.__class__) model_labels = find_labels(self.__class__)
if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
output_signature, _ = dataset._get_output_signature(
dataset,
batch_size=None,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
cols_to_retain=model_inputs,
)
else:
# TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
# argument. We should remove this once the minimum supported version of datasets is > 2.3.2
unwanted_columns = [ unwanted_columns = [
feature feature
for feature in dataset.features for feature in dataset.features
...@@ -1245,10 +1256,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1245,10 +1256,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
] ]
dataset = dataset.remove_columns(unwanted_columns) dataset = dataset.remove_columns(unwanted_columns)
output_signature, _ = dataset._get_output_signature( output_signature, _ = dataset._get_output_signature(
dataset, dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
batch_size=None,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
) )
output_columns = list(output_signature.keys()) output_columns = list(output_signature.keys())
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels] feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment