Commit b7798166 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 328206696
parent e69b6211
@@ -262,7 +262,7 @@ def create_retrieval_dataset(file_path,
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
       'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
       'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
-      'int_iden': tf.io.FixedLenFeature([1], tf.int64),
+      'example_id': tf.io.FixedLenFeature([1], tf.int64),
   }
   dataset = single_file_dataset(file_path, name_to_features)
@@ -278,7 +278,7 @@ def create_retrieval_dataset(file_path,
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
-    y = record['int_iden']
+    y = record['example_id']
     return (x, y)

   dataset = dataset.map(
......
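(Note, not part of the commit: a minimal sketch of the decode path these two hunks change, assuming `single_file_dataset` amounts to a `TFRecordDataset` plus `tf.io.parse_single_example`; the file name is made up.)

import tensorflow as tf

seq_length = 128
name_to_features = {
    'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
    'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
    'example_id': tf.io.FixedLenFeature([1], tf.int64),
}

def _select_data_from_record(record):
  # Mirrors the second hunk: model inputs go in x, example_id becomes y.
  x = {
      'input_word_ids': record['input_ids'],
      'input_mask': record['input_mask'],
      'input_type_ids': record['segment_ids'],
  }
  return (x, record['example_id'])

dataset = tf.data.TFRecordDataset('retrieval.tf_record')  # assumed path
dataset = dataset.map(
    lambda r: tf.io.parse_single_example(r, name_to_features))
dataset = dataset.map(_select_data_from_record)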
@@ -39,7 +39,7 @@ class InputExample(object):
                text_b=None,
                label=None,
                weight=None,
-               int_iden=None):
+               example_id=None):
     """Constructs a InputExample.

     Args:
@@ -53,15 +53,15 @@ class InputExample(object):
         examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
-      int_iden: (Optional) int. The int identification number of example in the
-        corpus.
+      example_id: (Optional) int. The int identification number of example in
+        the corpus.
     """
     self.guid = guid
     self.text_a = text_a
     self.text_b = text_b
     self.label = label
     self.weight = weight
-    self.int_iden = int_iden
+    self.example_id = example_id


 class InputFeatures(object):
@@ -74,14 +74,14 @@ class InputFeatures(object):
                label_id,
                is_real_example=True,
                weight=None,
-               int_iden=None):
+               example_id=None):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
     self.is_real_example = is_real_example
     self.weight = weight
-    self.int_iden = int_iden
+    self.example_id = example_id


 class DataProcessor(object):
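(Illustration only, values made up: after the rename, callers construct examples with the new keyword, as the BUCC/Tatoeba processors further down do.)

example = InputExample(guid='test-0', text_a='Hello world.', example_id=42)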
@@ -1050,7 +1050,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
     logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
     logging.info("label: %s (id = %s)", example.label, str(label_id))
     logging.info("weight: %s", example.weight)
-    logging.info("int_iden: %s", str(example.int_iden))
+    logging.info("example_id: %s", example.example_id)

   feature = InputFeatures(
       input_ids=input_ids,
@@ -1059,7 +1059,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
       label_id=label_id,
       is_real_example=True,
       weight=example.weight,
-      int_iden=example.int_iden)
+      example_id=example.example_id)

   return feature
@@ -1102,8 +1102,10 @@ def file_based_convert_examples_to_features(examples,
         [int(feature.is_real_example)])
     if feature.weight is not None:
       features["weight"] = create_float_feature([feature.weight])
-    if feature.int_iden is not None:
-      features["int_iden"] = create_int_feature([feature.int_iden])
+    if feature.example_id is not None:
+      features["example_id"] = create_int_feature([feature.example_id])
+    else:
+      features["example_id"] = create_int_feature([ex_index])

     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())
......
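(Sketch, not from the commit: the writer side of the hunk above, showing the new fallback — when a processor never set `example_id`, the example's running index is serialized instead, so every record carries an id. `features_list` and the output path are assumptions.)

import collections
import tensorflow as tf

def create_int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

with tf.io.TFRecordWriter('/tmp/train.tf_record') as writer:  # assumed path
  for ex_index, feature in enumerate(features_list):  # features_list assumed
    features = collections.OrderedDict()
    features['input_ids'] = create_int_feature(feature.input_ids)
    if feature.example_id is not None:
      features['example_id'] = create_int_feature([feature.example_id])
    else:
      # Fallback introduced by this commit: use the running index.
      features['example_id'] = create_int_feature([ex_index])
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())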
@@ -35,6 +35,8 @@ class SentencePredictionDataConfig(cfg.DataConfig):
   is_training: bool = True
   seq_length: int = 128
   label_type: str = 'int'
+  # Whether to include the example id number.
+  include_example_id: bool = False


 @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -44,6 +46,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
   def __init__(self, params):
     self._params = params
     self._seq_length = params.seq_length
+    self._include_example_id = params.include_example_id

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
@@ -54,6 +57,9 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'label_ids': tf.io.FixedLenFeature([], label_type),
     }
+    if self._include_example_id:
+      name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
+
     example = tf.io.parse_single_example(record, name_to_features)
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
@@ -73,6 +79,9 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
+    if self._include_example_id:
+      x['example_id'] = record['example_id']
+
     y = record['label_ids']
     return (x, y)
......
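(Minimal standalone sketch, assuming the loader behavior shown above: decoding with the new `include_example_id` flag enabled, including the int64-to-int32 cast the in-code comment refers to.)

import tensorflow as tf

seq_length = 128
include_example_id = True  # stands in for params.include_example_id

def decode(record):
  name_to_features = {
      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
      'label_ids': tf.io.FixedLenFeature([], tf.int64),
  }
  if include_example_id:
    name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
  example = tf.io.parse_single_example(record, name_to_features)
  # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
  # so cast every int64 feature (example_id included) down to int32.
  return {name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t
          for name, t in example.items()}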
@@ -49,11 +49,11 @@ class BuccProcessor(classifier_data_lib.DataProcessor):
     examples = []
     for (i, line) in enumerate(lines):
       guid = "%s-%s" % (set_type, i)
-      int_iden = int(line[0].split("-")[1])
+      example_id = int(line[0].split("-")[1])
       text_a = self.process_text_fn(line[1])
       examples.append(
           classifier_data_lib.InputExample(
-              guid=guid, text_a=text_a, int_iden=int_iden))
+              guid=guid, text_a=text_a, example_id=example_id))
     return examples
@@ -86,7 +86,7 @@ class TatoebaProcessor(classifier_data_lib.DataProcessor):
       text_a = self.process_text_fn(line[0])
       examples.append(
           classifier_data_lib.InputExample(
-              guid=guid, text_a=text_a, int_iden=i))
+              guid=guid, text_a=text_a, example_id=i))
     return examples
......
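(Toy illustration of the id parsing the BuccProcessor hunk relies on: BUCC-style identifiers look like 'en-000042', and the numeric half becomes the example id. The identifier value here is made up.)

line = ['en-000042', 'Some sentence.']   # hypothetical BUCC row
example_id = int(line[0].split('-')[1])  # -> 42, as in the hunk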
@@ -246,22 +246,29 @@ def predict(task: SentencePredictionTask, params: cfg.DataConfig,
   def predict_step(inputs):
     """Replicated prediction calculation."""
     x, _ = inputs
+    example_id = x.pop('example_id')
     outputs = task.inference_step(x, model)
     if is_regression:
-      return outputs
+      return dict(example_id=example_id, predictions=outputs)
     else:
-      return tf.argmax(outputs, axis=-1)
+      return dict(
+          example_id=example_id, predictions=tf.argmax(outputs, axis=-1))

   def aggregate_fn(state, outputs):
     """Concatenates model's outputs."""
     if state is None:
-      state = {'predictions': []}
+      state = []

-    for per_replica_batch_predictions in outputs:
-      state['predictions'].extend(per_replica_batch_predictions)
+    for per_replica_example_id, per_replica_batch_predictions in zip(
+        outputs['example_id'], outputs['predictions']):
+      state.extend(zip(per_replica_example_id, per_replica_batch_predictions))
     return state

   dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
                                                  task.build_inputs, params)
   outputs = utils.predict(predict_step, aggregate_fn, dataset)
-  return outputs['predictions']
+
+  # When running on TPU POD, the order of output cannot be maintained,
+  # so we need to sort by example_id.
+  outputs = sorted(outputs, key=lambda x: x[0])
+  return [x[1] for x in outputs]
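(Why the final sort works, as a toy sketch with made-up values: aggregate_fn now flattens (example_id, prediction) pairs across replicas, and since a TPU pod can deliver batches out of order, sorting on example_id restores dataset order before the ids are dropped.)

state = [(2, 'p2'), (3, 'p3'), (0, 'p0'), (1, 'p1')]  # (example_id, prediction)
state.sort(key=lambda x: x[0])
predictions = [p for _, p in state]  # ['p0', 'p1', 'p2', 'p3'], dataset order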
@@ -40,13 +40,14 @@ def _create_fake_dataset(output_path, seq_length, num_classes, num_examples):
   def create_float_feature(values):
     return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

-  for _ in range(num_examples):
+  for i in range(num_examples):
     features = {}
     input_ids = np.random.randint(100, size=(seq_length))
     features["input_ids"] = create_int_feature(input_ids)
     features["input_mask"] = create_int_feature(np.ones_like(input_ids))
     features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
+    features["example_id"] = create_int_feature([i])
     if num_classes == 1:
       features["label_ids"] = create_float_feature([np.random.random()])
@@ -250,7 +251,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
             is_training=False,
             label_type="int" if num_classes > 1 else "float",
             global_batch_size=16,
-            drop_remainder=False))
+            drop_remainder=False,
+            include_example_id=True))
     predictions = sentence_prediction.predict(task, test_data_config, model)
     self.assertLen(predictions, num_examples)
......
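(End-to-end usage, hedged: the test above exercises the new flag roughly as follows; field and variable names follow the hunks, but the input path and surrounding objects are assumptions.)

test_data_config = SentencePredictionDataConfig(
    input_path='/tmp/eval.tf_record',  # the fake dataset written above
    is_training=False,
    label_type='int',
    global_batch_size=16,
    drop_remainder=False,
    include_example_id=True)
predictions = sentence_prediction.predict(task, test_data_config, model)
assert len(predictions) == num_examples  # one prediction per input example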