Commit ad423d06 authored by Maxim Neumann, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 315855426
parent 0b23ad50
@@ -33,7 +33,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
   """A single training/test example for simple sequence classification."""
 
-  def __init__(self, guid, text_a, text_b=None, label=None):
+  def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
     """Constructs a InputExample.
 
     Args:
@@ -44,11 +44,14 @@ class InputExample(object):
         Only must be specified for sequence pair tasks.
       label: (Optional) string. The label of the example. This should be
         specified for train and dev examples, but not for test examples.
+      weight: (Optional) float. The weight of the example to be used during
+        training.
     """
     self.guid = guid
     self.text_a = text_a
     self.text_b = text_b
     self.label = label
+    self.weight = weight
 
 
 class InputFeatures(object):
@@ -59,12 +62,14 @@ class InputFeatures(object):
                input_mask,
                segment_ids,
                label_id,
-               is_real_example=True):
+               is_real_example=True,
+               weight=None):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
     self.is_real_example = is_real_example
+    self.weight = weight
 
 
 class DataProcessor(object):
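
For orientation, a minimal sketch of what the widened constructors accept after this change; the guid, text, label and the 0.3 weight below are made-up values, and omitting weight keeps the previous behaviour (it defaults to None):

example = InputExample(
    guid="train-0",
    text_a="a quietly moving character study",
    label="positive",
    weight=0.3)  # optional per-example weight; defaults to None when omitted
feature_weight = example.weight  # later copied unchanged onto InputFeatures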
@@ -574,6 +579,7 @@ class TfdsProcessor(DataProcessor):
     test_text_b_key: Key of the second text feature to use in test set.
     test_label: String to be used as the label for all test examples.
     label_type: Type of the label key (defaults to `int`).
+    weight_key: Key of the float sample weight (is not used if not provided).
     is_regression: Whether the task is a regression problem (defaults to False).
   """
@@ -612,6 +618,7 @@ class TfdsProcessor(DataProcessor):
     self.test_label = d.get("test_label", "test_example")
     self.label_type = dtype_map[d.get("label_type", "int")]
     self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
+    self.weight_key = d.get("weight_key", None)
 
   def get_train_examples(self, data_dir):
     assert data_dir is None
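
Assuming, as the surrounding constructor suggests, that the processor options arrive as comma-separated key=value pairs parsed into the dict d, a weighted task specification could look roughly like the sketch below; the dataset name and the example_weight feature key are illustrative, not taken from this commit:

# Illustrative spec only; keys other than weight_key mirror the d.get(...) calls above.
tfds_params = "dataset=glue/sst2,text_key=sentence,label_key=label,weight_key=example_weight"
# With weight_key present, get_examples() sets InputExample.weight to
# float(example["example_weight"]); leaving it out keeps weight at None.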
@@ -637,7 +644,7 @@ class TfdsProcessor(DataProcessor):
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
     examples = []
-    text_b = None
+    text_b, weight = None, None
     for i, example in enumerate(dataset):
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
@@ -650,8 +657,11 @@ class TfdsProcessor(DataProcessor):
         if self.text_b_key:
           text_b = self.process_text_fn(example[self.text_b_key])
         label = self.label_type(example[self.label_key])
+      if self.weight_key:
+        weight = float(example[self.weight_key])
       examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
+                       weight=weight))
     return examples
@@ -739,13 +749,15 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
     logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
     logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
     logging.info("label: %s (id = %d)", example.label, label_id)
+    logging.info("weight: %s", example.weight)
 
   feature = InputFeatures(
       input_ids=input_ids,
       input_mask=input_mask,
       segment_ids=segment_ids,
       label_id=label_id,
-      is_real_example=True)
+      is_real_example=True,
+      weight=example.weight)
 
   return feature
@@ -781,6 +793,8 @@ def file_based_convert_examples_to_features(examples, label_list,
     features["label_ids"] = create_int_feature([feature.label_id])
     features["is_real_example"] = create_int_feature(
         [int(feature.is_real_example)])
+    if feature.weight is not None:
+      features["weight"] = create_float_feature([feature.weight])
 
     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())
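
The new column only reaches a model if the reader asks for it when decoding. A minimal reading-side sketch, assuming a maximum sequence length of 128 and a placeholder file name; the default_value of 1.0 is a choice made here so that records written without a weight still parse:

import tensorflow as tf

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([128], tf.int64),
    "input_mask": tf.io.FixedLenFeature([128], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([128], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
    "is_real_example": tf.io.FixedLenFeature([], tf.int64),
    # Written only for weighted examples; the default keeps older files readable.
    "weight": tf.io.FixedLenFeature([], tf.float32, default_value=1.0),
}

def _decode(record):
  return tf.io.parse_single_example(record, name_to_features)

dataset = tf.data.TFRecordDataset("train.tf_record").map(_decode)  # path is illustrative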
@@ -837,6 +851,7 @@ def generate_tf_record_from_data_file(processor,
   label_list = processor.get_labels()
   label_type = getattr(processor, "label_type", None)
   is_regression = getattr(processor, "is_regression", False)
+  has_sample_weights = getattr(processor, "weight_key", False)
 
   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
@@ -879,6 +894,8 @@ def generate_tf_record_from_data_file(processor,
   else:
     meta_data["task_type"] = "bert_classification"
     meta_data["num_labels"] = len(processor.get_labels())
+  if has_sample_weights:
+    meta_data["has_sample_weights"] = True
 
   if eval_data_output_path:
     meta_data["eval_data_size"] = len(eval_input_data_examples)
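
A plausible consumer of the new flag, outside this commit, is a training loop that splits the decoded record into the (features, labels, sample_weight) triple that tf.keras Model.fit understands; meta_data, dataset and model below are assumed to already exist:

def _split_sample_weight(record):
  features = {k: v for k, v in record.items() if k not in ("label_ids", "weight")}
  # A three-element tuple is interpreted by Keras as (x, y, sample_weight).
  return features, record["label_ids"], record["weight"]

if meta_data.get("has_sample_weights"):
  dataset = dataset.map(_split_sample_weight)

model.fit(dataset.batch(32), epochs=3)  # hypothetical compiled Keras classifier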