Unverified Commit c79b550d authored by Jared T Nielsen, committed by GitHub

Add `qas_id` to SquadResult and SquadExample (#3745)

* Add qas_id

* Fix incorrect name in squad.py

* Make output files optional for squad eval
parent c4158a63
@@ -307,7 +307,7 @@ def evaluate(args, model, tokenizer, prefix=""):
             if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                 del inputs["token_type_ids"]
-            example_indices = batch[3]
+            feature_indices = batch[3]
             # XLNet and XLM use more arguments for their predictions
             if args.model_type in ["xlnet", "xlm"]:
@@ -320,8 +320,9 @@ def evaluate(args, model, tokenizer, prefix=""):
             outputs = model(**inputs)

-        for i, example_index in enumerate(example_indices):
-            eval_feature = features[example_index.item()]
+        for i, feature_index in enumerate(feature_indices):
+            # TODO: i and feature_index are the same number! Simplify by removing enumerate?
+            eval_feature = features[feature_index.item()]
             unique_id = int(eval_feature.unique_id)

             output = [to_list(output[i]) for output in outputs]
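For context, here is a minimal sketch of how the renamed `feature_indices` drive the evaluation loop. This is not the full run_squad.py code: `model`, `device`, `eval_dataloader`, and `features` are assumed to be set up by the surrounding script, index 3 of each batch holds the feature index (matching the TensorDataset change further down), and `SquadResult` is the processor's result container taking `(unique_id, start_logits, end_logits)`.

```python
import torch
from transformers.data.processors.squad import SquadResult

# Sketch only: `model`, `device`, `eval_dataloader`, and `features` are assumed to
# exist as in run_squad.py; batch[3] carries the per-feature index.
all_results = []
model.eval()
for batch in eval_dataloader:
    batch = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
        outputs = model(**inputs)
    feature_indices = batch[3]
    for i, feature_index in enumerate(feature_indices):
        # Look the feature back up so its unique_id (and now its qas_id) stays attached to the result.
        eval_feature = features[feature_index.item()]
        unique_id = int(eval_feature.unique_id)
        start_logits = outputs[0][i].detach().cpu().tolist()
        end_logits = outputs[1][i].detach().cpu().tolist()
        all_results.append(SquadResult(unique_id, start_logits, end_logits))
```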
@@ -384,8 +384,12 @@ def compute_predictions_logits(
     tokenizer,
 ):
     """Write final predictions to the json file and log-odds of null if needed."""
-    logger.info("Writing predictions to: %s" % (output_prediction_file))
-    logger.info("Writing nbest to: %s" % (output_nbest_file))
+    if output_prediction_file:
+        logger.info(f"Writing predictions to: {output_prediction_file}")
+    if output_nbest_file:
+        logger.info(f"Writing nbest to: {output_nbest_file}")
+    if output_null_log_odds_file and version_2_with_negative:
+        logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")

     example_index_to_features = collections.defaultdict(list)
     for feature in all_features:
@@ -554,13 +558,15 @@ def compute_predictions_logits(
                 all_predictions[example.qas_id] = best_non_null_entry.text
         all_nbest_json[example.qas_id] = nbest_json

+    if output_prediction_file:
         with open(output_prediction_file, "w") as writer:
             writer.write(json.dumps(all_predictions, indent=4) + "\n")

+    if output_nbest_file:
         with open(output_nbest_file, "w") as writer:
             writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

-    if version_2_with_negative:
+    if output_null_log_odds_file and version_2_with_negative:
         with open(output_null_log_odds_file, "w") as writer:
             writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
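With the output paths now optional, the metrics helper can be used purely in memory. A hedged usage sketch follows, assuming the `compute_predictions_logits` parameter list shown in this file and that `examples`, `features`, `all_results`, and `tokenizer` come from the usual preprocessing and evaluation steps; the function returns the qas_id-to-answer mapping either way.

```python
from transformers.data.metrics.squad_metrics import compute_predictions_logits

# Sketch only: skip all on-disk artifacts by passing None for the output paths.
# `examples`, `features`, `all_results`, and `tokenizer` are assumed to exist.
predictions = compute_predictions_logits(
    examples,
    features,
    all_results,
    n_best_size=20,
    max_answer_length=30,
    do_lower_case=True,
    output_prediction_file=None,      # no predictions file written
    output_nbest_file=None,           # no nbest file written
    output_null_log_odds_file=None,   # no null-odds file written
    verbose_logging=False,
    version_2_with_negative=True,
    null_score_diff_threshold=0.0,
    tokenizer=tokenizer,
)
# `predictions` maps each example's qas_id to its predicted answer text.
```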
@@ -251,6 +251,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
                 start_position=start_position,
                 end_position=end_position,
                 is_impossible=span_is_impossible,
+                qas_id=example.qas_id,
             )
         )
     return features
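Since each span now records its example's qas_id, a quick way to see the effect is to convert a dev set without building a dataset. A hedged sketch, assuming `tokenizer` and `data_dir` exist and using the processor API of this version:

```python
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

# Sketch only: `tokenizer` and `data_dir` are assumed.
examples = SquadV2Processor().get_dev_examples(data_dir)
features = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset=False,  # plain list of SquadFeatures
)
# Long contexts yield several features per example, all sharing the same qas_id.
assert all(f.qas_id is not None for f in features)
```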
@@ -344,9 +345,9 @@ def squad_convert_examples_to_features(
         all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)

         if not is_training:
-            all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
+            all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
             dataset = TensorDataset(
-                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
+                all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
             )
         else:
             all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
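For the PyTorch evaluation path, the tensor renamed here is the fourth entry of the TensorDataset, which is why run_squad.py reads batch[3] as `feature_indices`. A hedged sketch of building and unpacking that dataset, with `examples` and `tokenizer` assumed as above:

```python
from torch.utils.data import DataLoader, SequentialSampler
from transformers.data.processors.squad import squad_convert_examples_to_features

# Sketch only: with return_dataset="pt" the helper returns (features, dataset);
# for evaluation the dataset columns are input_ids, attention_mask,
# token_type_ids, feature_index, cls_index, p_mask.
features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
)
eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
batch = next(iter(eval_dataloader))
feature_indices = batch[3]  # rows of this batch, as indices into `features`
```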
@@ -368,12 +369,14 @@ def squad_convert_examples_to_features(
             raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")

         def gen():
-            for ex in features:
+            for i, ex in enumerate(features):
                 yield (
                     {
                         "input_ids": ex.input_ids,
                         "attention_mask": ex.attention_mask,
                         "token_type_ids": ex.token_type_ids,
+                        "feature_index": i,
+                        "qas_id": ex.qas_id,
                     },
                     {
                         "start_position": ex.start_position,
@@ -384,10 +387,15 @@ def squad_convert_examples_to_features(
                     },
                 )

-        return tf.data.Dataset.from_generator(
-            gen,
-            (
-                {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
+        # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
+        train_types = (
+            {
+                "input_ids": tf.int32,
+                "attention_mask": tf.int32,
+                "token_type_ids": tf.int32,
+                "feature_index": tf.int64,
+                "qas_id": tf.string,
+            },
             {
                 "start_position": tf.int64,
                 "end_position": tf.int64,
@@ -395,12 +403,15 @@ def squad_convert_examples_to_features(
                 "p_mask": tf.int32,
                 "is_impossible": tf.int32,
             },
-            ),
-            (
+        )
+
+        train_shapes = (
             {
                 "input_ids": tf.TensorShape([None]),
                 "attention_mask": tf.TensorShape([None]),
                 "token_type_ids": tf.TensorShape([None]),
+                "feature_index": tf.TensorShape([]),
+                "qas_id": tf.TensorShape([]),
             },
             {
                 "start_position": tf.TensorShape([]),
@@ -409,9 +420,10 @@ def squad_convert_examples_to_features(
                 "p_mask": tf.TensorShape([None]),
                 "is_impossible": tf.TensorShape([]),
             },
-            ),
         )
+
+        return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
     else:
         return features
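On the TensorFlow side, `gen`, `train_types`, and `train_shapes` above define each element as an (inputs, targets) pair, with `feature_index` and the string `qas_id` now part of the inputs dict. A hedged consumption sketch, with `examples` and `tokenizer` assumed as above:

```python
from transformers.data.processors.squad import squad_convert_examples_to_features

# Sketch only: with return_dataset="tf" the helper returns the tf.data.Dataset itself.
tf_dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
    return_dataset="tf",
)
for inputs, targets in tf_dataset.batch(4).take(1):
    print(inputs["feature_index"].numpy())                        # int64 indices, shape (4,)
    print([q.decode("utf-8") for q in inputs["qas_id"].numpy()])  # SQuAD question ids
```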
@@ -678,6 +690,7 @@ class SquadFeatures(object):
         start_position,
         end_position,
         is_impossible,
+        qas_id: str = None,
     ):
         self.input_ids = input_ids
         self.attention_mask = attention_mask
@@ -695,6 +708,7 @@ class SquadFeatures(object):
         self.start_position = start_position
         self.end_position = end_position
         self.is_impossible = is_impossible
+        self.qas_id = qas_id


 class SquadResult(object):
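A hedged illustration of what storing qas_id on SquadFeatures enables: features, and any per-feature results keyed by unique_id, can be grouped directly by question id rather than through a separate example index. The names `features`, `unique_id_to_result`, and `some_qas_id` are assumptions for the sketch.

```python
import collections

# Sketch only: `features` is the list from squad_convert_examples_to_features,
# `unique_id_to_result` maps unique_id -> SquadResult, and `some_qas_id` is any
# question id from the dataset.
features_per_question = collections.defaultdict(list)
for feature in features:
    features_per_question[feature.qas_id].append(feature)

# Gather every result produced for one question, however many spans it was split into.
results_for_question = [
    unique_id_to_result[f.unique_id] for f in features_per_question[some_qas_id]
]
```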