"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "287427e2e66aef4e4d857cfd666fe849e9f73617"
Commit ad423d06 authored by Maxim Neumann, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 315855426
parent 0b23ad50
@@ -33,7 +33,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
   """A single training/test example for simple sequence classification."""

-  def __init__(self, guid, text_a, text_b=None, label=None):
+  def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
     """Constructs a InputExample.

     Args:
@@ -44,11 +44,14 @@ class InputExample(object):
         Only must be specified for sequence pair tasks.
       label: (Optional) string. The label of the example. This should be
         specified for train and dev examples, but not for test examples.
+      weight: (Optional) float. The weight of the example to be used during
+        training.
     """
     self.guid = guid
     self.text_a = text_a
     self.text_b = text_b
     self.label = label
+    self.weight = weight


 class InputFeatures(object):
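As a side note, here is a minimal sketch of how the extended constructor might be called. The guid, texts, label, and the 0.3 weight are made-up values, and the import path is an assumption about where this module lives:

from official.nlp.data import classifier_data_lib  # module path assumed

# A weighted example; the weight is simply stored on the instance.
weighted = classifier_data_lib.InputExample(
    guid="train-0",
    text_a="The movie was surprisingly good.",
    label="1",
    weight=0.3)

# Existing call sites keep working unchanged, since weight defaults to None.
unweighted = classifier_data_lib.InputExample(
    guid="train-1", text_a="No sample weight here.", label="0")
assert unweighted.weight is None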
@@ -59,12 +62,14 @@ class InputFeatures(object):
                input_mask,
                segment_ids,
                label_id,
-               is_real_example=True):
+               is_real_example=True,
+               weight=None):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
     self.is_real_example = is_real_example
+    self.weight = weight


 class DataProcessor(object):
@@ -574,6 +579,7 @@ class TfdsProcessor(DataProcessor):
     test_text_b_key: Key of the second text feature to use in test set.
     test_label: String to be used as the label for all test examples.
     label_type: Type of the label key (defaults to `int`).
+    weight_key: Key of the float sample weight (is not used if not provided).
     is_regression: Whether the task is a regression problem (defaults to False).
   """
@@ -612,6 +618,7 @@ class TfdsProcessor(DataProcessor):
     self.test_label = d.get("test_label", "test_example")
     self.label_type = dtype_map[d.get("label_type", "int")]
     self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
+    self.weight_key = d.get("weight_key", None)

   def get_train_examples(self, data_dir):
     assert data_dir is None
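For illustration, a sketch of how the new weight_key parameter might be supplied. The dataset and key names are invented, and the parsing shown is only an approximation of the comma-separated key=value format that the TfdsProcessor docstring describes:

# Hypothetical tfds_params string opting in to sample weights (names invented).
tfds_params = ("dataset=my_dataset/config,"
               "text_key=sentence,"
               "label_key=label,"
               "weight_key=example_weight")

# Rough sketch of turning the string into the dict `d` read above.
d = dict(kv.split("=", 1) for kv in tfds_params.split(","))
assert d.get("weight_key", None) == "example_weight"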
@@ -637,7 +644,7 @@
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
     examples = []
-    text_b = None
+    text_b, weight = None, None
     for i, example in enumerate(dataset):
       guid = "%s-%s" % (set_type, i)
       if set_type == "test":
@@ -650,8 +657,11 @@
         if self.text_b_key:
           text_b = self.process_text_fn(example[self.text_b_key])
         label = self.label_type(example[self.label_key])
+      if self.weight_key:
+        weight = float(example[self.weight_key])
       examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
+                       weight=weight))
     return examples
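To make the branch above concrete, a small standalone sketch of the weight extraction, using a plain dict in place of one element from as_numpy_iterator(); all feature names are made up:

fake_example = {"sentence": b"good movie", "label": 1, "example_weight": 0.25}
weight_key = "example_weight"  # stands in for self.weight_key

weight = None
if weight_key:
  weight = float(fake_example[weight_key])  # numpy scalar -> Python float
assert weight == 0.25
# When weight_key is unset, weight stays None and the InputExample is built
# with weight=None, matching the previous behaviour.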
@@ -739,13 +749,15 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
     logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
     logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
     logging.info("label: %s (id = %d)", example.label, label_id)
+    logging.info("weight: %s", example.weight)

   feature = InputFeatures(
       input_ids=input_ids,
       input_mask=input_mask,
       segment_ids=segment_ids,
       label_id=label_id,
-      is_real_example=True)
+      is_real_example=True,
+      weight=example.weight)

   return feature
@@ -781,6 +793,8 @@ def file_based_convert_examples_to_features(examples, label_list,
     features["label_ids"] = create_int_feature([feature.label_id])
     features["is_real_example"] = create_int_feature(
         [int(feature.is_real_example)])
+    if feature.weight is not None:
+      features["weight"] = create_float_feature([feature.weight])

     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())
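On the reading side, a record written with the optional "weight" feature could be parsed along these lines. The sequence length and feature names are assumptions about the surrounding input pipeline, not part of this change:

import tensorflow as tf

_SEQ_LENGTH = 128  # illustrative value

name_to_features = {
    "input_ids": tf.io.FixedLenFeature([_SEQ_LENGTH], tf.int64),
    "input_mask": tf.io.FixedLenFeature([_SEQ_LENGTH], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([_SEQ_LENGTH], tf.int64),
    "label_ids": tf.io.FixedLenFeature([], tf.int64),
    # Only include this entry when the metadata says weights were exported.
    "weight": tf.io.FixedLenFeature([], tf.float32),
}

def decode_record(record):
  """Parses one serialized tf.train.Example produced by the writer above."""
  return tf.io.parse_single_example(record, name_to_features)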
@@ -837,6 +851,7 @@ def generate_tf_record_from_data_file(processor,
   label_list = processor.get_labels()
   label_type = getattr(processor, "label_type", None)
   is_regression = getattr(processor, "is_regression", False)
+  has_sample_weights = getattr(processor, "weight_key", False)

   assert train_data_output_path
   train_input_data_examples = processor.get_train_examples(data_dir)
@@ -879,6 +894,8 @@
   else:
     meta_data["task_type"] = "bert_classification"
     meta_data["num_labels"] = len(processor.get_labels())
+  if has_sample_weights:
+    meta_data["has_sample_weights"] = True

   if eval_data_output_path:
     meta_data["eval_data_size"] = len(eval_input_data_examples)
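A hypothetical consumer of the new metadata flag, sketched under the assumption that parsed records carry the keys used above: when has_sample_weights is set, each record is reshaped into a (features, label, weight) triple, which tf.keras Model.fit treats as per-example loss weights.

def to_keras_inputs(parsed, has_sample_weights):
  """Splits a parsed record into the tuple shape expected by tf.keras."""
  features = {k: v for k, v in parsed.items()
              if k not in ("label_ids", "weight")}
  if has_sample_weights:
    return features, parsed["label_ids"], parsed["weight"]
  return features, parsed["label_ids"]

# dataset = dataset.map(
#     lambda rec: to_keras_inputs(rec, meta_data.get("has_sample_weights", False)))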