Commit 2a38d9a4 authored by Maxim Neumann's avatar Maxim Neumann Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 320144607
parent 52bb4ab1
...@@ -729,6 +729,8 @@ class TfdsProcessor(DataProcessor): ...@@ -729,6 +729,8 @@ class TfdsProcessor(DataProcessor):
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2" tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2," tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
"is_regression=true,label_type=float" "is_regression=true,label_type=float"
tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
"skip_label=-1"
Possible parameters (please refer to the documentation of Tensorflow Datasets Possible parameters (please refer to the documentation of Tensorflow Datasets
(TFDS) for the meaning of individual parameters): (TFDS) for the meaning of individual parameters):
dataset: Required dataset name (potentially with subset and version number). dataset: Required dataset name (potentially with subset and version number).
...@@ -746,6 +748,7 @@ class TfdsProcessor(DataProcessor): ...@@ -746,6 +748,7 @@ class TfdsProcessor(DataProcessor):
label_type: Type of the label key (defaults to `int`). label_type: Type of the label key (defaults to `int`).
weight_key: Key of the float sample weight (is not used if not provided). weight_key: Key of the float sample weight (is not used if not provided).
is_regression: Whether the task is a regression problem (defaults to False). is_regression: Whether the task is a regression problem (defaults to False).
skip_label: Skip examples with given label (defaults to None).
""" """
def __init__(self, def __init__(self,
...@@ -785,6 +788,9 @@ class TfdsProcessor(DataProcessor): ...@@ -785,6 +788,9 @@ class TfdsProcessor(DataProcessor):
self.label_type = dtype_map[d.get("label_type", "int")] self.label_type = dtype_map[d.get("label_type", "int")]
self.is_regression = cast_str_to_bool(d.get("is_regression", "False")) self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
self.weight_key = d.get("weight_key", None) self.weight_key = d.get("weight_key", None)
self.skip_label = d.get("skip_label", None)
if self.skip_label is not None:
self.skip_label = self.label_type(self.skip_label)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
assert data_dir is None assert data_dir is None
...@@ -823,6 +829,8 @@ class TfdsProcessor(DataProcessor): ...@@ -823,6 +829,8 @@ class TfdsProcessor(DataProcessor):
if self.text_b_key: if self.text_b_key:
text_b = self.process_text_fn(example[self.text_b_key]) text_b = self.process_text_fn(example[self.text_b_key])
label = self.label_type(example[self.label_key]) label = self.label_type(example[self.label_key])
if self.skip_label is not None and label == self.skip_label:
continue
if self.weight_key: if self.weight_key:
weight = float(example[self.weight_key]) weight = float(example[self.weight_key])
examples.append( examples.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment