Unverified Commit de46cde1 authored by Matt's avatar Matt Committed by GitHub
Browse files

Drop columns after loading samples in prepare_tf_dataset (#17967)

* Drop columns after loading samples, rather than before, to avoid breaking transforms

* make fixup

* Add workaround so this PR can work with current datasets version
parent 2544c143
......@@ -1238,18 +1238,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise TypeError("Dataset argument should be a datasets.Dataset!")
model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
model_labels = find_labels(self.__class__)
unwanted_columns = [
feature
for feature in dataset.features
if feature not in model_inputs and feature not in ("label_ids", "label")
]
dataset = dataset.remove_columns(unwanted_columns)
output_signature, _ = dataset._get_output_signature(
dataset,
batch_size=None,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
)
if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
output_signature, _ = dataset._get_output_signature(
dataset,
batch_size=None,
collate_fn=collate_fn,
collate_fn_args=collate_fn_args,
cols_to_retain=model_inputs,
)
else:
# TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
# argument. We should remove this once the minimum supported version of datasets is > 2.3.2
unwanted_columns = [
feature
for feature in dataset.features
if feature not in model_inputs and feature not in ("label_ids", "label")
]
dataset = dataset.remove_columns(unwanted_columns)
output_signature, _ = dataset._get_output_signature(
dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
)
output_columns = list(output_signature.keys())
feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
label_cols = [col for col in output_columns if col in model_labels]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment