"...csrc/git@developer.sourcefind.cn:change/sglang.git" did not exist on "7ad6b766c589cc51f4716b1d2052d66ac1a135fb"
Unverified commit c79b550d, authored by Jared T Nielsen, committed by GitHub
Browse files

Add `qas_id` to SquadResult and SquadExample (#3745)

* Add qas_id

* Fix incorrect name in squad.py

* Make output files optional for squad eval
parent c4158a63
...@@ -307,7 +307,7 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -307,7 +307,7 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
del inputs["token_type_ids"] del inputs["token_type_ids"]
example_indices = batch[3] feature_indices = batch[3]
# XLNet and XLM use more arguments for their predictions # XLNet and XLM use more arguments for their predictions
if args.model_type in ["xlnet", "xlm"]: if args.model_type in ["xlnet", "xlm"]:
...@@ -320,8 +320,9 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -320,8 +320,9 @@ def evaluate(args, model, tokenizer, prefix=""):
outputs = model(**inputs) outputs = model(**inputs)
for i, example_index in enumerate(example_indices): for i, feature_index in enumerate(feature_indices):
eval_feature = features[example_index.item()] # TODO: i and feature_index are the same number! Simplify by removing enumerate?
eval_feature = features[feature_index.item()]
unique_id = int(eval_feature.unique_id) unique_id = int(eval_feature.unique_id)
output = [to_list(output[i]) for output in outputs] output = [to_list(output[i]) for output in outputs]
......
...@@ -384,8 +384,12 @@ def compute_predictions_logits( ...@@ -384,8 +384,12 @@ def compute_predictions_logits(
tokenizer, tokenizer,
): ):
"""Write final predictions to the json file and log-odds of null if needed.""" """Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file)) if output_prediction_file:
logger.info("Writing nbest to: %s" % (output_nbest_file)) logger.info(f"Writing predictions to: {output_prediction_file}")
if output_nbest_file:
logger.info(f"Writing nbest to: {output_nbest_file}")
if output_null_log_odds_file and version_2_with_negative:
logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
example_index_to_features = collections.defaultdict(list) example_index_to_features = collections.defaultdict(list)
for feature in all_features: for feature in all_features:
...@@ -554,13 +558,15 @@ def compute_predictions_logits( ...@@ -554,13 +558,15 @@ def compute_predictions_logits(
all_predictions[example.qas_id] = best_non_null_entry.text all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json all_nbest_json[example.qas_id] = nbest_json
if output_prediction_file:
with open(output_prediction_file, "w") as writer: with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n") writer.write(json.dumps(all_predictions, indent=4) + "\n")
if output_nbest_file:
with open(output_nbest_file, "w") as writer: with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n") writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative: if output_null_log_odds_file and version_2_with_negative:
with open(output_null_log_odds_file, "w") as writer: with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n") writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
......
...@@ -251,6 +251,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q ...@@ -251,6 +251,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
start_position=start_position, start_position=start_position,
end_position=end_position, end_position=end_position,
is_impossible=span_is_impossible, is_impossible=span_is_impossible,
qas_id=example.qas_id,
) )
) )
return features return features
...@@ -344,9 +345,9 @@ def squad_convert_examples_to_features( ...@@ -344,9 +345,9 @@ def squad_convert_examples_to_features(
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float) all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
if not is_training: if not is_training:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
dataset = TensorDataset( dataset = TensorDataset(
all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
) )
else: else:
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
...@@ -368,12 +369,14 @@ def squad_convert_examples_to_features( ...@@ -368,12 +369,14 @@ def squad_convert_examples_to_features(
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.") raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
def gen(): def gen():
for ex in features: for i, ex in enumerate(features):
yield ( yield (
{ {
"input_ids": ex.input_ids, "input_ids": ex.input_ids,
"attention_mask": ex.attention_mask, "attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids, "token_type_ids": ex.token_type_ids,
"feature_index": i,
"qas_id": ex.qas_id,
}, },
{ {
"start_position": ex.start_position, "start_position": ex.start_position,
...@@ -384,10 +387,15 @@ def squad_convert_examples_to_features( ...@@ -384,10 +387,15 @@ def squad_convert_examples_to_features(
}, },
) )
return tf.data.Dataset.from_generator( # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
gen, train_types = (
( {
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, "input_ids": tf.int32,
"attention_mask": tf.int32,
"token_type_ids": tf.int32,
"feature_index": tf.int64,
"qas_id": tf.string,
},
{ {
"start_position": tf.int64, "start_position": tf.int64,
"end_position": tf.int64, "end_position": tf.int64,
...@@ -395,12 +403,15 @@ def squad_convert_examples_to_features( ...@@ -395,12 +403,15 @@ def squad_convert_examples_to_features(
"p_mask": tf.int32, "p_mask": tf.int32,
"is_impossible": tf.int32, "is_impossible": tf.int32,
}, },
), )
(
train_shapes = (
{ {
"input_ids": tf.TensorShape([None]), "input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]), "token_type_ids": tf.TensorShape([None]),
"feature_index": tf.TensorShape([]),
"qas_id": tf.TensorShape([]),
}, },
{ {
"start_position": tf.TensorShape([]), "start_position": tf.TensorShape([]),
...@@ -409,9 +420,10 @@ def squad_convert_examples_to_features( ...@@ -409,9 +420,10 @@ def squad_convert_examples_to_features(
"p_mask": tf.TensorShape([None]), "p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]), "is_impossible": tf.TensorShape([]),
}, },
),
) )
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
else:
return features return features
...@@ -678,6 +690,7 @@ class SquadFeatures(object): ...@@ -678,6 +690,7 @@ class SquadFeatures(object):
start_position, start_position,
end_position, end_position,
is_impossible, is_impossible,
qas_id: str = None,
): ):
self.input_ids = input_ids self.input_ids = input_ids
self.attention_mask = attention_mask self.attention_mask = attention_mask
...@@ -695,6 +708,7 @@ class SquadFeatures(object): ...@@ -695,6 +708,7 @@ class SquadFeatures(object):
self.start_position = start_position self.start_position = start_position
self.end_position = end_position self.end_position = end_position
self.is_impossible = is_impossible self.is_impossible = is_impossible
self.qas_id = qas_id
class SquadResult(object): class SquadResult(object):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment